mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-20 07:54:14 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef75a89fdb | ||
|
|
d8b5cdc4fe | ||
|
|
dea9ba27cb | ||
|
|
c6d1a00aa7 | ||
|
|
424c579455 | ||
|
|
e9f9483464 | ||
|
|
41c5e02f42 | ||
|
|
2e1c9cd814 | ||
|
|
190c4838bd | ||
|
|
e7c2cf1356 | ||
|
|
1257491047 | ||
|
|
083e18b11c | ||
|
|
3d94e967a1 | ||
|
|
7feb0a1005 | ||
|
|
0a8026e768 | ||
|
|
5ceed62421 | ||
|
|
7ca5991d2b | ||
|
|
b3e3060f4e | ||
|
|
37adc9c6ba |
40
.github/workflows/build.yml
vendored
40
.github/workflows/build.yml
vendored
@@ -547,6 +547,46 @@ jobs:
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 3600
|
||||
|
||||
ubuntu-24-wasm-webgpu:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-latest-wasm-webgpu
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Install Emscripten
|
||||
run: |
|
||||
git clone https://github.com/emscripten-core/emsdk.git
|
||||
cd emsdk
|
||||
./emsdk install latest
|
||||
./emsdk activate latest
|
||||
|
||||
- name: Fetch emdawnwebgpu
|
||||
run: |
|
||||
DAWN_TAG="v20251027.212519"
|
||||
EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip"
|
||||
echo "Downloading ${EMDAWN_PKG}"
|
||||
curl -L -o emdawn.zip \
|
||||
"https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}"
|
||||
unzip emdawn.zip
|
||||
|
||||
- name: Build WASM WebGPU
|
||||
run: |
|
||||
source emsdk/emsdk_env.sh
|
||||
emcmake cmake -B build-wasm \
|
||||
-DGGML_WEBGPU=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg
|
||||
|
||||
cmake --build build-wasm --target test-backend-ops -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
|
||||
71
.github/workflows/release.yml
vendored
71
.github/workflows/release.yml
vendored
@@ -728,58 +728,6 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
|
||||
|
||||
openEuler-cann:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake libcurl-devel
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts (zip)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
|
||||
- name: Upload artifacts (tar)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
|
||||
|
||||
release:
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
|
||||
@@ -801,7 +749,6 @@ jobs:
|
||||
- macOS-arm64
|
||||
- macOS-x64
|
||||
- ios-xcode-build
|
||||
- openEuler-cann
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -869,6 +816,12 @@ jobs:
|
||||
> [!WARNING]
|
||||
> **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts.
|
||||
|
||||
<details open>
|
||||
|
||||
${{ github.event.head_commit.message }}
|
||||
|
||||
</details>
|
||||
|
||||
**macOS/iOS:**
|
||||
- [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz)
|
||||
- [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz)
|
||||
@@ -887,18 +840,6 @@ jobs:
|
||||
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
|
||||
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)
|
||||
|
||||
**openEuler:**
|
||||
- [openEuler x86 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-x86.tar.gz)
|
||||
- [openEuler x86 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-x86.tar.gz)
|
||||
- [openEuler aarch64 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-aarch64.tar.gz)
|
||||
- [openEuler aarch64 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-aarch64.tar.gz)
|
||||
|
||||
<details>
|
||||
|
||||
${{ github.event.head_commit.message }}
|
||||
|
||||
</details>
|
||||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
uses: actions/github-script@v3
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -134,3 +134,5 @@ poetry.toml
|
||||
# IDE
|
||||
/*.code-workspace
|
||||
/.windsurf/
|
||||
# emscripten
|
||||
a.out.*
|
||||
|
||||
@@ -33,10 +33,24 @@ endif()
|
||||
|
||||
option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)
|
||||
|
||||
option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON)
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
|
||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
|
||||
# Use 64-bit memory to support backend_get_memory queries
|
||||
# TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using
|
||||
if (LLAMA_WASM_MEM64)
|
||||
add_compile_options("-sMEMORY64=1")
|
||||
add_link_options("-sMEMORY64=1")
|
||||
endif()
|
||||
add_link_options("-sALLOW_MEMORY_GROWTH=1")
|
||||
|
||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF)
|
||||
option(LLAMA_BUILD_HTML "llama: build HTML file" ON)
|
||||
if (LLAMA_BUILD_HTML)
|
||||
set(CMAKE_EXECUTABLE_SUFFIX ".html")
|
||||
endif()
|
||||
else()
|
||||
if (MINGW)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
@@ -58,6 +72,12 @@ if (MSVC)
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
||||
endif()
|
||||
|
||||
if (LLAMA_STANDALONE)
|
||||
# enable parallel builds for msbuild
|
||||
list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true)
|
||||
list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
||||
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
|
||||
else()
|
||||
@@ -179,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
|
||||
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
# Target Windows 8 for PrefetchVirtualMemory
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# build the library
|
||||
#
|
||||
|
||||
@@ -10,13 +10,16 @@
|
||||
/common/arg.* @ggerganov
|
||||
/common/base64.hpp.* @ggerganov
|
||||
/common/build-info.* @ggerganov
|
||||
/common/chat-peg-parser.* @aldehir
|
||||
/common/common.* @ggerganov
|
||||
/common/console.* @ggerganov
|
||||
/common/http.* @angt
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
/common/unicode.* @aldehir
|
||||
/convert_*.py @CISC
|
||||
/examples/batched.swift/ @ggerganov
|
||||
/examples/batched/ @ggerganov
|
||||
|
||||
@@ -52,6 +52,8 @@ add_library(${TARGET} STATIC
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat-peg-parser.cpp
|
||||
chat-peg-parser.h
|
||||
chat.cpp
|
||||
chat.h
|
||||
common.cpp
|
||||
@@ -69,12 +71,16 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
speculative.h
|
||||
unicode.cpp
|
||||
unicode.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#include <thread> // for hardware_concurrency
|
||||
#include <vector>
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
#include <linux/limits.h>
|
||||
#elif defined(_WIN32)
|
||||
@@ -41,6 +42,8 @@
|
||||
#else
|
||||
#include <sys/syslimits.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#include "chat-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "peg-parser.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -1483,6 +1485,11 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
return common_chat_peg_parse(syntax.parser, input, is_partial, syntax);
|
||||
}
|
||||
common_chat_msg_parser builder(input, is_partial, syntax);
|
||||
try {
|
||||
common_chat_parse(builder);
|
||||
@@ -1500,3 +1507,36 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
if (parser.empty()) {
|
||||
throw std::runtime_error("Failed to parse due to missing parser definition.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str());
|
||||
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
auto result = parser.parse(ctx);
|
||||
if (result.fail()) {
|
||||
throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end));
|
||||
}
|
||||
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) {
|
||||
auto mapper = common_chat_peg_native_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
} else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
auto mapper = common_chat_peg_constructed_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
} else {
|
||||
// Generic mapper
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
}
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
114
common/chat-peg-parser.cpp
Normal file
114
common/chat-peg-parser.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
#include "chat-peg-parser.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
static std::string_view trim_trailing_space(std::string_view sv) {
|
||||
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
|
||||
sv.remove_suffix(1);
|
||||
}
|
||||
return sv;
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
||||
arena.visit(result, [this](const common_peg_ast_node & node) {
|
||||
map(node);
|
||||
});
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
||||
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
||||
|
||||
if (is_reasoning) {
|
||||
result.reasoning_content = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_content) {
|
||||
result.content = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
|
||||
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
|
||||
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
}
|
||||
|
||||
if (is_tool_id && current_tool) {
|
||||
current_tool->id = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_name && current_tool) {
|
||||
current_tool->name = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_args && current_tool) {
|
||||
current_tool->arguments = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
|
||||
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
|
||||
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
|
||||
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
|
||||
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
|
||||
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
|
||||
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
arg_count = 0;
|
||||
}
|
||||
|
||||
if (is_tool_name) {
|
||||
current_tool->name = std::string(node.text);
|
||||
current_tool->arguments = "{";
|
||||
}
|
||||
|
||||
if (is_arg_open) {
|
||||
needs_closing_quote = false;
|
||||
}
|
||||
|
||||
if (is_arg_name && current_tool) {
|
||||
if (arg_count > 0) {
|
||||
current_tool->arguments += ",";
|
||||
}
|
||||
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
|
||||
++arg_count;
|
||||
}
|
||||
|
||||
if (is_arg_string && current_tool) {
|
||||
// Serialize to JSON, but exclude the end quote
|
||||
std::string dumped = json(node.text).dump();
|
||||
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
|
||||
needs_closing_quote = true;
|
||||
}
|
||||
|
||||
if (is_arg_close && current_tool) {
|
||||
if (needs_closing_quote) {
|
||||
current_tool->arguments += "\"";
|
||||
}
|
||||
}
|
||||
|
||||
if (is_arg_json && current_tool) {
|
||||
current_tool->arguments += std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_close && current_tool) {
|
||||
current_tool->arguments += "}";
|
||||
}
|
||||
}
|
||||
105
common/chat-peg-parser.h
Normal file
105
common/chat-peg-parser.h
Normal file
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
public:
|
||||
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
||||
static constexpr const char * REASONING = "reasoning";
|
||||
static constexpr const char * CONTENT = "content";
|
||||
|
||||
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
||||
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
||||
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
|
||||
common_chat_peg_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
class common_chat_peg_mapper {
|
||||
public:
|
||||
common_chat_msg & result;
|
||||
|
||||
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
|
||||
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
virtual void map(const common_peg_ast_node & node);
|
||||
};
|
||||
|
||||
class common_chat_peg_native_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_ID = "tool-id";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARGS = "tool-args";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
|
||||
public:
|
||||
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
|
||||
common_chat_peg_native_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARG = "tool-arg";
|
||||
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
||||
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
||||
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
||||
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
|
||||
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
|
||||
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
|
||||
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
|
||||
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
|
||||
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
int arg_count = 0;
|
||||
bool needs_closing_quote = false;
|
||||
|
||||
public:
|
||||
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
|
||||
common_chat_peg_constructed_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
@@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const
|
||||
return message;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
if (previous_msg.reasoning_content != new_msg.reasoning_content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
|
||||
}
|
||||
if (previous_msg.content != new_msg.content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
|
||||
if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) {
|
||||
diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3);
|
||||
} else {
|
||||
diffs.reserve(3);
|
||||
}
|
||||
|
||||
if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
|
||||
// TODO: these can become expensive for long messages - how to optimize?
|
||||
if (msg_prv.reasoning_content != msg_new.reasoning_content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content);
|
||||
}
|
||||
if (msg_prv.content != msg_new.content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.content_delta = string_diff(msg_prv.content, msg_new.content);
|
||||
}
|
||||
|
||||
if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) {
|
||||
throw std::runtime_error("Invalid diff: now finding less tool calls!");
|
||||
}
|
||||
|
||||
if (!previous_msg.tool_calls.empty()) {
|
||||
auto idx = previous_msg.tool_calls.size() - 1;
|
||||
const auto & pref = previous_msg.tool_calls[idx];
|
||||
const auto & newf = new_msg.tool_calls[idx];
|
||||
if (!msg_prv.tool_calls.empty()) {
|
||||
const auto idx = msg_prv.tool_calls.size() - 1;
|
||||
const auto & pref = msg_prv.tool_calls[idx];
|
||||
const auto & newf = msg_new.tool_calls[idx];
|
||||
if (pref.name != newf.name) {
|
||||
throw std::runtime_error("Invalid diff: tool call mismatch!");
|
||||
}
|
||||
auto args_diff = string_diff(pref.arguments, newf.arguments);
|
||||
const auto args_diff = string_diff(pref.arguments, newf.arguments);
|
||||
if (!args_diff.empty() || pref.id != newf.id) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.tool_call_index = idx;
|
||||
@@ -118,11 +125,12 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
|
||||
diff.tool_call_delta.arguments = args_diff;
|
||||
}
|
||||
}
|
||||
for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
|
||||
for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.tool_call_index = idx;
|
||||
diff.tool_call_delta = new_msg.tool_calls[idx];
|
||||
diff.tool_call_delta = msg_new.tool_calls[idx];
|
||||
}
|
||||
|
||||
return diffs;
|
||||
}
|
||||
|
||||
@@ -649,6 +657,9 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
|
||||
case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
|
||||
case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
@@ -76,7 +77,7 @@ struct common_chat_msg_diff {
|
||||
size_t tool_call_index = std::string::npos;
|
||||
common_chat_tool_call tool_call_delta;
|
||||
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
|
||||
|
||||
bool operator==(const common_chat_msg_diff & other) const {
|
||||
return content_delta == other.content_delta
|
||||
@@ -124,6 +125,11 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
|
||||
// These are intended to be parsed by the PEG parser
|
||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -154,6 +160,7 @@ struct common_chat_params {
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
@@ -163,6 +170,7 @@ struct common_chat_syntax {
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
@@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
|
||||
@@ -902,6 +902,8 @@ std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#elif defined(__EMSCRIPTEN__)
|
||||
GGML_ABORT("not implemented on this platform");
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
#else
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "http.h"
|
||||
#endif
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
#include <linux/limits.h>
|
||||
#elif defined(_WIN32)
|
||||
@@ -35,6 +36,8 @@
|
||||
#else
|
||||
#include <sys/syslimits.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||
|
||||
// isatty
|
||||
|
||||
1712
common/peg-parser.cpp
Normal file
1712
common/peg-parser.cpp
Normal file
File diff suppressed because it is too large
Load Diff
459
common/peg-parser.h
Normal file
459
common/peg-parser.h
Normal file
@@ -0,0 +1,459 @@
|
||||
#pragma once
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <variant>
|
||||
|
||||
struct common_grammar_builder;
|
||||
|
||||
class common_peg_parser_builder;
|
||||
|
||||
using common_peg_parser_id = size_t;
|
||||
constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast<common_peg_parser_id>(-1);
|
||||
|
||||
using common_peg_ast_id = size_t;
|
||||
constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast<common_peg_ast_id>(-1);
|
||||
|
||||
// Lightweight wrapper around common_peg_parser_id for convenience
|
||||
class common_peg_parser {
|
||||
common_peg_parser_id id_;
|
||||
common_peg_parser_builder & builder_;
|
||||
|
||||
public:
|
||||
common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {}
|
||||
common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {}
|
||||
|
||||
common_peg_parser & operator=(const common_peg_parser & other);
|
||||
common_peg_parser & operator+=(const common_peg_parser & other);
|
||||
common_peg_parser & operator|=(const common_peg_parser & other);
|
||||
|
||||
operator common_peg_parser_id() const { return id_; }
|
||||
common_peg_parser_id id() const { return id_; }
|
||||
|
||||
common_peg_parser_builder & builder() const { return builder_; }
|
||||
|
||||
// Creates a sequence
|
||||
common_peg_parser operator+(const common_peg_parser & other) const;
|
||||
|
||||
// Creates a sequence separated by spaces.
|
||||
common_peg_parser operator<<(const common_peg_parser & other) const;
|
||||
|
||||
// Creates a choice
|
||||
common_peg_parser operator|(const common_peg_parser & other) const;
|
||||
|
||||
common_peg_parser operator+(const char * str) const;
|
||||
common_peg_parser operator+(const std::string & str) const;
|
||||
common_peg_parser operator<<(const char * str) const;
|
||||
common_peg_parser operator<<(const std::string & str) const;
|
||||
common_peg_parser operator|(const char * str) const;
|
||||
common_peg_parser operator|(const std::string & str) const;
|
||||
};
|
||||
|
||||
common_peg_parser operator+(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator+(const std::string & str, const common_peg_parser & p);
|
||||
common_peg_parser operator<<(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator<<(const std::string & str, const common_peg_parser & p);
|
||||
common_peg_parser operator|(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator|(const std::string & str, const common_peg_parser & p);
|
||||
|
||||
enum common_peg_parse_result_type {
|
||||
COMMON_PEG_PARSE_RESULT_FAIL = 0,
|
||||
COMMON_PEG_PARSE_RESULT_SUCCESS = 1,
|
||||
COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2,
|
||||
};
|
||||
|
||||
const char * common_peg_parse_result_type_name(common_peg_parse_result_type type);
|
||||
|
||||
struct common_peg_ast_node {
|
||||
common_peg_ast_id id;
|
||||
std::string rule;
|
||||
std::string tag;
|
||||
size_t start;
|
||||
size_t end;
|
||||
std::string_view text;
|
||||
std::vector<common_peg_ast_id> children;
|
||||
|
||||
bool is_partial = false;
|
||||
};
|
||||
|
||||
struct common_peg_parse_result;
|
||||
|
||||
using common_peg_ast_visitor = std::function<void(const common_peg_ast_node & node)>;
|
||||
|
||||
class common_peg_ast_arena {
|
||||
std::vector<common_peg_ast_node> nodes_;
|
||||
public:
|
||||
common_peg_ast_id add_node(
|
||||
const std::string & rule,
|
||||
const std::string & tag,
|
||||
size_t start,
|
||||
size_t end,
|
||||
std::string_view text,
|
||||
std::vector<common_peg_ast_id> children,
|
||||
bool is_partial = false
|
||||
) {
|
||||
common_peg_ast_id id = nodes_.size();
|
||||
nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial});
|
||||
return id;
|
||||
}
|
||||
|
||||
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
|
||||
|
||||
size_t size() const { return nodes_.size(); }
|
||||
|
||||
void clear() { nodes_.clear(); }
|
||||
|
||||
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
|
||||
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
|
||||
};
|
||||
|
||||
struct common_peg_parse_result {
|
||||
common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL;
|
||||
size_t start = 0;
|
||||
size_t end = 0;
|
||||
|
||||
std::vector<common_peg_ast_id> nodes;
|
||||
|
||||
common_peg_parse_result() = default;
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start)
|
||||
: type(type), start(start), end(start) {}
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end)
|
||||
: type(type), start(start), end(end) {}
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector<common_peg_ast_id> nodes)
|
||||
: type(type), start(start), end(end), nodes(std::move(nodes)) {}
|
||||
|
||||
bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; }
|
||||
bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; }
|
||||
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
||||
};
|
||||
|
||||
struct common_peg_parse_context {
|
||||
std::string input;
|
||||
bool is_partial;
|
||||
common_peg_ast_arena ast;
|
||||
|
||||
int parse_depth;
|
||||
|
||||
common_peg_parse_context()
|
||||
: is_partial(false), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input)
|
||||
: input(input), is_partial(false), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input, bool is_partial)
|
||||
: input(input), is_partial(is_partial), parse_depth(0) {}
|
||||
};
|
||||
|
||||
class common_peg_arena;
|
||||
|
||||
// Parser variants
|
||||
struct common_peg_epsilon_parser {};
|
||||
|
||||
struct common_peg_start_parser {};
|
||||
|
||||
struct common_peg_end_parser {};
|
||||
|
||||
struct common_peg_literal_parser {
|
||||
std::string literal;
|
||||
};
|
||||
|
||||
struct common_peg_sequence_parser {
|
||||
std::vector<common_peg_parser_id> children;
|
||||
};
|
||||
|
||||
struct common_peg_choice_parser {
|
||||
std::vector<common_peg_parser_id> children;
|
||||
};
|
||||
|
||||
struct common_peg_repetition_parser {
|
||||
common_peg_parser_id child;
|
||||
int min_count;
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_and_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_not_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_any_parser {};
|
||||
|
||||
struct common_peg_space_parser {};
|
||||
|
||||
struct common_peg_chars_parser {
|
||||
struct char_range {
|
||||
uint32_t start;
|
||||
uint32_t end;
|
||||
bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; }
|
||||
};
|
||||
|
||||
std::string pattern;
|
||||
std::vector<char_range> ranges;
|
||||
bool negated;
|
||||
int min_count;
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_json_string_parser {};
|
||||
|
||||
struct common_peg_until_parser {
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
struct common_peg_schema_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string name;
|
||||
std::shared_ptr<nlohmann::ordered_json> schema;
|
||||
|
||||
// Indicates if the GBNF should accept a raw string that matches the schema.
|
||||
bool raw;
|
||||
};
|
||||
|
||||
struct common_peg_rule_parser {
|
||||
std::string name;
|
||||
common_peg_parser_id child;
|
||||
bool trigger;
|
||||
};
|
||||
|
||||
struct common_peg_ref_parser {
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct common_peg_atomic_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_tag_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
common_peg_start_parser,
|
||||
common_peg_end_parser,
|
||||
common_peg_literal_parser,
|
||||
common_peg_sequence_parser,
|
||||
common_peg_choice_parser,
|
||||
common_peg_repetition_parser,
|
||||
common_peg_and_parser,
|
||||
common_peg_not_parser,
|
||||
common_peg_any_parser,
|
||||
common_peg_space_parser,
|
||||
common_peg_chars_parser,
|
||||
common_peg_json_string_parser,
|
||||
common_peg_until_parser,
|
||||
common_peg_schema_parser,
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
std::vector<common_peg_parser_variant> parsers_;
|
||||
std::unordered_map<std::string, common_peg_parser_id> rules_;
|
||||
common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID;
|
||||
|
||||
public:
|
||||
const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); }
|
||||
common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); }
|
||||
|
||||
size_t size() const { return parsers_.size(); }
|
||||
bool empty() const { return parsers_.empty(); }
|
||||
|
||||
common_peg_parser_id get_rule(const std::string & name) const;
|
||||
bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); }
|
||||
|
||||
common_peg_parser_id root() const { return root_; }
|
||||
void set_root(common_peg_parser_id id) { root_ = id; }
|
||||
|
||||
common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const;
|
||||
common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const;
|
||||
|
||||
void resolve_refs();
|
||||
|
||||
void build_grammar(const common_grammar_builder & builder, bool lazy = false) const;
|
||||
|
||||
std::string dump(common_peg_parser_id id) const;
|
||||
|
||||
nlohmann::json to_json() const;
|
||||
static common_peg_arena from_json(const nlohmann::json & j);
|
||||
|
||||
std::string save() const;
|
||||
void load(const std::string & data);
|
||||
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
|
||||
common_peg_parser_id resolve_ref(common_peg_parser_id id);
|
||||
};
|
||||
|
||||
class common_peg_parser_builder {
|
||||
common_peg_arena arena_;
|
||||
|
||||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||
|
||||
public:
|
||||
common_peg_parser_builder();
|
||||
|
||||
// Match nothing, always succeed.
|
||||
// S -> ε
|
||||
common_peg_parser eps() { return add(common_peg_epsilon_parser{}); }
|
||||
|
||||
// Matches the start of the input.
|
||||
// S -> ^
|
||||
common_peg_parser start() { return add(common_peg_start_parser{}); }
|
||||
|
||||
// Matches the end of the input.
|
||||
// S -> $
|
||||
common_peg_parser end() { return add(common_peg_end_parser{}); }
|
||||
|
||||
// Matches an exact literal string.
|
||||
// S -> "hello"
|
||||
common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); }
|
||||
|
||||
// Matches a sequence of parsers in order, all must succeed.
|
||||
// S -> A B C
|
||||
common_peg_parser sequence() { return add(common_peg_sequence_parser{}); }
|
||||
common_peg_parser sequence(const std::vector<common_peg_parser_id> & parsers);
|
||||
common_peg_parser sequence(const std::vector<common_peg_parser> & parsers);
|
||||
common_peg_parser sequence(std::initializer_list<common_peg_parser> parsers);
|
||||
|
||||
// Matches the first parser that succeeds from a list of alternatives.
|
||||
// S -> A | B | C
|
||||
common_peg_parser choice() { return add(common_peg_choice_parser{}); }
|
||||
common_peg_parser choice(const std::vector<common_peg_parser_id> & parsers);
|
||||
common_peg_parser choice(const std::vector<common_peg_parser> & parsers);
|
||||
common_peg_parser choice(std::initializer_list<common_peg_parser> parsers);
|
||||
|
||||
// Matches one or more repetitions of a parser.
|
||||
// S -> A+
|
||||
common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); }
|
||||
|
||||
// Matches zero or more repetitions of a parser, always succeeds.
|
||||
// S -> A*
|
||||
common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); }
|
||||
|
||||
// Matches zero or one occurrence of a parser, always succeeds.
|
||||
// S -> A?
|
||||
common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); }
|
||||
|
||||
// Positive lookahead: succeeds if child parser succeeds, consumes no input.
|
||||
// S -> &A
|
||||
common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); }
|
||||
|
||||
// Negative lookahead: succeeds if child parser fails, consumes no input.
|
||||
// S -> !A
|
||||
common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); }
|
||||
|
||||
// Matches any single character.
|
||||
// S -> .
|
||||
common_peg_parser any() { return add(common_peg_any_parser{}); }
|
||||
|
||||
// Matches between min and max repetitions of characters from a character class.
|
||||
// S -> [a-z]{m,n}
|
||||
//
|
||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||
common_peg_parser chars(const std::string & classes, int min = 1, int max = -1);
|
||||
|
||||
// Creates a lightweight reference to a named rule (resolved during build()).
|
||||
// Use this for forward references in recursive grammars.
|
||||
// expr_ref -> expr
|
||||
common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); }
|
||||
|
||||
// Matches zero or more whitespace characters (space, tab, newline).
|
||||
// S -> [ \t\n]*
|
||||
common_peg_parser space() { return add(common_peg_space_parser{}); }
|
||||
|
||||
// Matches all characters until a delimiter is found (delimiter not consumed).
|
||||
// S -> (!delim .)*
|
||||
common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); }
|
||||
|
||||
// Matches all characters until one of the delimiters in the list is found (delimiter not consumed).
|
||||
// S -> (!delim .)*
|
||||
common_peg_parser until_one_of(const std::vector<std::string> & delimiters) { return add(common_peg_until_parser{delimiters}); }
|
||||
|
||||
// Matches everything
|
||||
// S -> .*
|
||||
common_peg_parser rest() { return until_one_of({}); }
|
||||
|
||||
// Matches between min and max repetitions of a parser (inclusive).
|
||||
// S -> A{m,n}
|
||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||
common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); }
|
||||
|
||||
// Matches exactly n repetitions of a parser.
|
||||
// S -> A{n}
|
||||
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
|
||||
|
||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||
// value -> object | array | string | number | true | false | null
|
||||
common_peg_parser json();
|
||||
common_peg_parser json_object();
|
||||
common_peg_parser json_string();
|
||||
common_peg_parser json_array();
|
||||
common_peg_parser json_number();
|
||||
common_peg_parser json_bool();
|
||||
common_peg_parser json_null();
|
||||
|
||||
// Matches JSON string content without the surrounding quotes.
|
||||
// Useful for extracting content within a JSON string.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches a JSON object member with a key and associated parser as the
|
||||
// value.
|
||||
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
|
||||
|
||||
// Wraps a parser with JSON schema metadata for grammar generation.
|
||||
// Used internally to convert JSON schemas to GBNF grammar rules.
|
||||
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
|
||||
|
||||
// Creates a named rule, stores it in the grammar, and returns a ref.
|
||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||
// auto json = p.rule("json", json_obj | json_arr | ...)
|
||||
common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false);
|
||||
|
||||
// Creates a named rule using a builder function, and returns a ref.
|
||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||
// auto json = p.rule("json", [&]() { return json_object() | json_array() | ... })
|
||||
common_peg_parser rule(const std::string & name, const std::function<common_peg_parser()> & builder, bool trigger = false);
|
||||
|
||||
// Creates a trigger rule. When generating a lazy grammar from the parser,
|
||||
// only trigger rules and descendents are emitted.
|
||||
common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); }
|
||||
common_peg_parser trigger_rule(const std::string & name, const std::function<common_peg_parser()> & builder) { return rule(name, builder, true); }
|
||||
|
||||
// Creates an atomic parser. Atomic parsers do not create an AST node if
|
||||
// the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is
|
||||
// intended for situations where partial output is undesirable.
|
||||
common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); }
|
||||
|
||||
// Tags create nodes in the generated AST for semantic purposes.
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
};
|
||||
|
||||
// Helper function for building parsers
|
||||
common_peg_arena build_peg_parser(const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);
|
||||
64
common/unicode.cpp
Normal file
64
common/unicode.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
#include "unicode.h"
|
||||
|
||||
// implementation adopted from src/unicode.cpp
|
||||
|
||||
size_t utf8_sequence_length(unsigned char first_byte) {
|
||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t highbits = static_cast<uint8_t>(first_byte) >> 4;
|
||||
return lookup[highbits];
|
||||
}
|
||||
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
|
||||
if (offset >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
|
||||
// ASCII fast path
|
||||
if (!(input[offset] & 0x80)) {
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1);
|
||||
}
|
||||
|
||||
// Invalid: continuation byte as first byte
|
||||
if (!(input[offset] & 0x40)) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
|
||||
// 2-byte sequence
|
||||
if (!(input[offset] & 0x20)) {
|
||||
if (offset + 1 >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
if ((input[offset + 1] & 0xc0) != 0x80) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f);
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2);
|
||||
}
|
||||
|
||||
// 3-byte sequence
|
||||
if (!(input[offset] & 0x10)) {
|
||||
if (offset + 2 >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f);
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3);
|
||||
}
|
||||
|
||||
// 4-byte sequence
|
||||
if (!(input[offset] & 0x08)) {
|
||||
if (offset + 3 >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f);
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4);
|
||||
}
|
||||
|
||||
// Invalid first byte
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
22
common/unicode.h
Normal file
22
common/unicode.h
Normal file
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <string_view>
|
||||
|
||||
// UTF-8 parsing utilities for streaming-aware unicode support
|
||||
|
||||
struct utf8_parse_result {
|
||||
uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS)
|
||||
size_t bytes_consumed; // How many bytes this codepoint uses (1-4)
|
||||
enum status { SUCCESS, INCOMPLETE, INVALID } status;
|
||||
|
||||
utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0)
|
||||
: codepoint(cp), bytes_consumed(bytes), status(s) {}
|
||||
};
|
||||
|
||||
// Determine the expected length of a UTF-8 sequence from its first byte
|
||||
// Returns 0 for invalid first bytes
|
||||
size_t utf8_sequence_length(unsigned char first_byte);
|
||||
|
||||
// Parse a single UTF-8 codepoint from input
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
288
docs/development/parsing.md
Normal file
288
docs/development/parsing.md
Normal file
@@ -0,0 +1,288 @@
|
||||
# Parsing Model Output
|
||||
|
||||
The `common` library contains a PEG parser implementation suitable for parsing
|
||||
model output.
|
||||
|
||||
Types with the prefix `common_peg_*` are intended for general use and may have
|
||||
applications beyond parsing model output, such as parsing user-provided regex
|
||||
patterns.
|
||||
|
||||
Types with the prefix `common_chat_peg_*` are specialized helpers for model
|
||||
output.
|
||||
|
||||
The parser features:
|
||||
|
||||
- Partial parsing of streaming input
|
||||
- Built-in JSON parsers
|
||||
- AST generation with semantics via "tagged" nodes
|
||||
|
||||
## Example
|
||||
|
||||
Below is a contrived example demonstrating how to use the PEG parser to parse
|
||||
output from a model that emits arguments as JSON.
|
||||
|
||||
```cpp
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\"");
|
||||
auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}");
|
||||
}
|
||||
|
||||
// Define the tool call structure: <tool_call>[{tool}]</tool_call>
|
||||
auto tool_call = p.trigger_rule("tool-call",
|
||||
p.sequence({
|
||||
p.literal("<tool_call>["),
|
||||
tool_choice,
|
||||
p.literal("]</tool_call>")
|
||||
})
|
||||
);
|
||||
|
||||
// Parser accepts content, optionally followed by a tool call
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.optional(tool_call),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
For a more complete example, see `test_example_native()` in
|
||||
[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp).
|
||||
|
||||
## Parsers/Combinators
|
||||
|
||||
### Basic Matchers
|
||||
|
||||
- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match)
|
||||
- **`start()`** - Matches the start of input (anchor `^`)
|
||||
- **`end()`** - Matches the end of input (anchor `$`)
|
||||
- **`literal(string)`** - Matches an exact literal string
|
||||
- **`any()`** - Matches any single character (`.`)
|
||||
|
||||
### Combinators
|
||||
|
||||
- **`sequence(...)`** - Matches parsers in order; all must succeed
|
||||
- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice)
|
||||
- **`one_or_more(p)`** - Matches one or more repetitions (`+`)
|
||||
- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`)
|
||||
- **`optional(p)`** - Matches zero or one occurrence (`?`)
|
||||
- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded)
|
||||
- **`repeat(p, n)`** - Matches exactly n repetitions
|
||||
|
||||
### Lookahead
|
||||
|
||||
- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`)
|
||||
- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`)
|
||||
|
||||
### Character Classes & Utilities
|
||||
|
||||
- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class
|
||||
- **`space()`** - Matches zero or more whitespace characters (space, tab, newline)
|
||||
- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed)
|
||||
- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found
|
||||
- **`rest()`** - Matches everything remaining (`.*`)
|
||||
|
||||
### JSON Parsers
|
||||
|
||||
- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null)
|
||||
- **`json_object()`** - JSON object parser
|
||||
- **`json_array()`** - JSON array parser
|
||||
- **`json_string()`** - JSON string parser
|
||||
- **`json_number()`** - JSON number parser
|
||||
- **`json_bool()`** - JSON boolean parser
|
||||
- **`json_null()`** - JSON null parser
|
||||
- **`json_string_content()`** - JSON string content without surrounding quotes
|
||||
- **`json_member(key, p)`** - JSON object member with specific key and value parser
|
||||
|
||||
### Grammar Building
|
||||
|
||||
- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars)
|
||||
- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference
|
||||
- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation)
|
||||
- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for grammar generation
|
||||
|
||||
### AST Control
|
||||
|
||||
- **`atomic(p)`** - Prevents AST node creation for partial parses
|
||||
- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags)
|
||||
|
||||
## GBNF Grammar Generation
|
||||
|
||||
The PEG parser also acts as a convenient DSL for generating GBNF grammars, with
|
||||
some exceptions.
|
||||
|
||||
```cpp
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(params.tools, [&](const json & fn) {
|
||||
builder.resolve_refs(fn.at("parameters"));
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
```
|
||||
|
||||
The notable exception is the `negate(p)` lookahead parser, which cannot be
|
||||
defined as a CFG grammar and therefore does not produce a rule. Its usage
|
||||
should be limited and preferably hidden behind a `schema()` parser. In many
|
||||
cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice.
|
||||
|
||||
Another limitation is that the PEG parser requires an unambiguous grammar. In
|
||||
contrast, the `llama-grammar` implementation can support ambiguous grammars,
|
||||
though they are difficult to parse.
|
||||
|
||||
### Lazy Grammars
|
||||
|
||||
During lazy grammar generation, only rules reachable from a `trigger_rule(p)`
|
||||
are emitted in the grammar. All trigger rules are added as alternations in the
|
||||
root rule. It is still necessary to define trigger patterns, as the parser has
|
||||
no interaction with the grammar sampling.
|
||||
|
||||
### JSON Schema
|
||||
|
||||
The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar`
|
||||
implementation to generate the grammar instead of the underlying parser.
|
||||
|
||||
The `raw` option emits a grammar suitable for a raw string instead of a JSON
|
||||
string. In other words, it won't be wrapped in quotes or require escaping
|
||||
quotes. It should only be used when `type == "string"`.
|
||||
|
||||
The downside is that it can potentially lead to ambiguous grammars. For
|
||||
example, if a user provides the pattern `^.*$`, the following grammar may be
|
||||
generated:
|
||||
|
||||
```
|
||||
root ::= "<arg>" .* "</arg>"
|
||||
```
|
||||
|
||||
This creates an ambiguous grammar that cannot be parsed by the PEG parser. To
|
||||
help mitigate this, if `.*` is found in the pattern, the grammar from the
|
||||
underlying parser will be emitted instead.
|
||||
|
||||
## Common AST Shapes for Chat Parsing
|
||||
|
||||
Most model output can be placed in one of the following categories:
|
||||
|
||||
- Content only
|
||||
- Tool calling with arguments emitted as a single JSON object
|
||||
- Tool calling with arguments emitted as separate entities, either XML
|
||||
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
|
||||
|
||||
To provide broad coverage,
|
||||
[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and
|
||||
mappers that help create parsers and visitors/extractors for these types. They
|
||||
require parsers to tag nodes to conform to an AST "shape". This normalization
|
||||
makes it easy to extract information and generalize parsing.
|
||||
|
||||
### Simple
|
||||
|
||||
The `common_chat_peg_builder` builds a `simple` parser that supports
|
||||
content-only models with optional reasoning.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for extracting `reasoning_content`
|
||||
- **`content(p)`** - Tag node for extracting `content`
|
||||
|
||||
```cpp
|
||||
build_chat_peg_parser([&](common_chat_peg_parser & p) {
|
||||
return p.sequence({
|
||||
p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>"),
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
Use `common_chat_peg_mapper` to extract the content. Note that this is already
|
||||
done for you in `common_chat_peg_parser` when
|
||||
`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`.
|
||||
|
||||
```cpp
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
```
|
||||
|
||||
### Native
|
||||
|
||||
The `common_chat_peg_native_builder` builds a `native` parser suitable for
|
||||
models that emit tool arguments as a direct JSON object.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
- **`content(p)`** - Tag node for `content`
|
||||
- **`tool(p)`** - Tag entirety of a single tool call
|
||||
- **`tool_open(p)`** - Tag start of a tool call
|
||||
- **`tool_close(p)`** - Tag end of a tool call
|
||||
- **`tool_id(p)`** - Tag the tool call ID (optional)
|
||||
- **`tool_name(p)`** - Tag the tool name
|
||||
- **`tool_args(p)`** - Tag the tool arguments
|
||||
|
||||
```cpp
|
||||
build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) {
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open(p.literal("{")),
|
||||
p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""),
|
||||
p.literal(","),
|
||||
p.json_member("arguments", p.tool_args(p.json())),
|
||||
p.tool_close(p.literal("}"))
|
||||
}));
|
||||
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.literal("<tool_call>"),
|
||||
get_weather_tool,
|
||||
p.literal("</tool_call>"),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Constructed
|
||||
|
||||
The `common_chat_peg_constructed_builder` builds a `constructed` parser
|
||||
suitable for models that emit tool arguments as separate entities, such as XML
|
||||
tags.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
- **`content(p)`** - Tag node for `content`
|
||||
- **`tool(p)`** - Tag entirety of a single tool call
|
||||
- **`tool_open(p)`** - Tag start of a tool call
|
||||
- **`tool_close(p)`** - Tag end of a tool call
|
||||
- **`tool_name(p)`** - Tag the tool name
|
||||
- **`tool_arg(p)`** - Tag a complete tool argument (name + value)
|
||||
- **`tool_arg_open(p)`** - Tag start of a tool argument
|
||||
- **`tool_arg_close(p)`** - Tag end of a tool argument
|
||||
- **`tool_arg_name(p)`** - Tag the argument name
|
||||
- **`tool_arg_string_value(p)`** - Tag string value for the argument
|
||||
- **`tool_arg_json_value(p)`** - Tag JSON value for the argument
|
||||
|
||||
```cpp
|
||||
build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
|
||||
auto location_arg = p.tool_arg(
|
||||
p.tool_arg_open("<parameter name=\"" + p.tool_arg_name(p.literal("location")) + "\">"),
|
||||
p.tool_arg_string_value(p.until("</parameter>")),
|
||||
p.tool_arg_close(p.literal("</parameter>"))
|
||||
);
|
||||
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open("<function name=\"" + p.tool_name(p.literal("get_weather")) + "\">"),
|
||||
location_arg,
|
||||
p.tool_close(p.literal("</function>"))
|
||||
}));
|
||||
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.literal("<tool_call>"),
|
||||
get_weather_tool,
|
||||
p.literal("</tool_call>"),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
@@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
|
||||
|
||||
|
||||
if (MINGW)
|
||||
set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version")
|
||||
endif()
|
||||
|
||||
# ggml core
|
||||
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
||||
option(GGML_CPU "ggml: enable CPU backend" ON)
|
||||
@@ -226,7 +221,7 @@ option(GGML_WEBGPU "ggml: use WebGPU"
|
||||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
||||
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
|
||||
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
|
||||
|
||||
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
|
||||
@@ -204,6 +204,10 @@
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
# define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@@ -2279,7 +2283,7 @@ extern "C" {
|
||||
float stop,
|
||||
float step);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
#define GGML_KQ_MASK_PAD 1
|
||||
|
||||
// q: [n_embd_k, n_batch, n_head, ne3 ]
|
||||
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]
|
||||
|
||||
@@ -127,10 +127,6 @@ if (NOT MSVC)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# POSIX conformance
|
||||
#
|
||||
|
||||
@@ -2698,6 +2698,11 @@ struct ggml_cplan ggml_graph_plan(
|
||||
n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
|
||||
}
|
||||
|
||||
#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
|
||||
// Emscripten without pthreads support can only use a single thread
|
||||
n_threads = 1;
|
||||
#endif
|
||||
|
||||
size_t work_size = 0;
|
||||
|
||||
struct ggml_cplan cplan;
|
||||
|
||||
@@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16(
|
||||
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
||||
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
||||
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
||||
} else {
|
||||
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
||||
|
||||
@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
|
||||
const int nbatch_fa) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
|
||||
template <int DV, int ncols1, int ncols2>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
@@ -790,8 +791,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
@@ -915,7 +914,7 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
@@ -970,6 +969,9 @@ void launch_fattn(
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
// TODO other tensor dimensions after removal of WMMA kernel:
|
||||
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
|
||||
|
||||
GGML_ASSERT(block_dim.x % warp_size == 0);
|
||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||
(const char *) Q->data,
|
||||
@@ -980,7 +982,7 @@ void launch_fattn(
|
||||
KV_max.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
||||
@@ -995,7 +997,7 @@ void launch_fattn(
|
||||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half * const __restrict__ mask,
|
||||
const uint3 ne01,
|
||||
const float logit_softcap,
|
||||
const float slope,
|
||||
T_KQ * const KQ,
|
||||
@@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
float * const KQ_sum,
|
||||
T_acc * const VKQ,
|
||||
const int k_VKQ_0,
|
||||
const int k_VKQ_max) {
|
||||
const int k_VKQ_max,
|
||||
const int col_Q_0) {
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
@@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
// Apply logit softcap + mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < cpw; ++jc0) {
|
||||
const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
|
||||
const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
|
||||
@@ -736,7 +738,7 @@ static __global__ void flash_attn_tile(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
@@ -781,11 +783,11 @@ static __global__ void flash_attn_tile(
|
||||
const int sequence = blockIdx.z / (ne02/ncols2);
|
||||
const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0);
|
||||
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
|
||||
|
||||
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
|
||||
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
|
||||
|
||||
const int stride_K2 = nb11 / sizeof(half2);
|
||||
const int stride_V2 = nb21 / sizeof(half2);
|
||||
@@ -842,11 +844,9 @@ static __global__ void flash_attn_tile(
|
||||
for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
|
||||
if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ncols1 == 1 || col_Q_0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>
|
||||
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
|
||||
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>
|
||||
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
|
||||
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
@@ -881,23 +881,23 @@ static __global__ void flash_attn_tile(
|
||||
while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
|
||||
constexpr bool oob_check = false;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
k_VKQ_0 += gridDim.y*nbatch_fa;
|
||||
}
|
||||
if (k_VKQ_0 < k_VKQ_max) {
|
||||
constexpr bool oob_check = true;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
}
|
||||
} else {
|
||||
// Branch without out-of-bounds checks.
|
||||
for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
|
||||
constexpr bool oob_check = false;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile(
|
||||
const int j = jc / ncols2;
|
||||
const int c = jc % ncols2;
|
||||
|
||||
if (ncols1 > 1 && col_Q_0 + j >= ne01) {
|
||||
if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
|
||||
const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
|
||||
|
||||
@@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
@@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 1 && ic0 + j >= ne01) {
|
||||
if (ncols > 1 && ic0 + j >= int(ne01.z)) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
@@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
|
||||
float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
|
||||
if (ncols == 1 || ic0 + j < ne01) {
|
||||
if (ncols == 1 || ic0 + j < int(ne01.z)) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
@@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
if (ncols == 1 || ic0 + j < ne01) {
|
||||
if (ncols == 1 || ic0 + j < int(ne01.z)) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
@@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
|
||||
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
|
||||
}
|
||||
|
||||
@@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= ne01) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_sum[j_VKQ];
|
||||
}
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
||||
dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
|
||||
dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
|
||||
@@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (i0 + warp_size > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
|
||||
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
|
||||
}
|
||||
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
|
||||
@@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
|
||||
}
|
||||
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
@@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j_VKQ = j0 + threadIdx.y;
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
if (ic0 + j_VKQ >= int(ne01.z)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||
@@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
#if defined(GGML_USE_MUSA)
|
||||
#define GGML_USE_WMMA_FATTN
|
||||
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
#endif // defined(GGML_USE_MUSA)
|
||||
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
|
||||
|
||||
@@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
@@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
@@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
// If Turing tensor cores are available, use them:
|
||||
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
@@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
int gqa_ratio_eff = 1;
|
||||
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
|
||||
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
||||
gqa_ratio_eff *= 2;
|
||||
}
|
||||
if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
if (Q->ne[1] * gqa_ratio_eff <= 16) {
|
||||
return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
|
||||
}
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
|
||||
@@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
||||
|
||||
namespace ggml_cuda_mma {
|
||||
|
||||
// Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
|
||||
// effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
|
||||
// In those cases the data can be split in different ways across the warp.
|
||||
enum data_layout {
|
||||
// By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
|
||||
// For the A/C matrices this means I major == row major, J major == column major.
|
||||
// For the B matrix this means I major == column major, J major == row major.
|
||||
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
||||
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
|
||||
DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
|
||||
DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
|
||||
};
|
||||
// Implemented mma combinations are:
|
||||
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
||||
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
||||
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
||||
|
||||
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
||||
struct tile {};
|
||||
|
||||
template <int I_, int J_, typename T>
|
||||
struct tile {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 64;
|
||||
@@ -131,9 +152,9 @@ namespace ggml_cuda_mma {
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 32 && J == 8) {
|
||||
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
|
||||
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
|
||||
#else
|
||||
return (l & 2) | (threadIdx.x & ~2);
|
||||
return (l & 2) + (threadIdx.x & ~2);
|
||||
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -143,7 +164,7 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x & 2) | (l & (4 + 1));
|
||||
return (threadIdx.x & 2) + (l & (4 + 1));
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -196,9 +217,9 @@ namespace ggml_cuda_mma {
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l / 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l / 2) * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
|
||||
return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
|
||||
} else {
|
||||
@@ -211,11 +232,11 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
return (l * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((threadIdx.x % 4) * 2) | (l % 2);
|
||||
return ((threadIdx.x % 4) * 2) + (l % 2);
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
|
||||
return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
|
||||
} else {
|
||||
@@ -227,26 +248,24 @@ namespace ggml_cuda_mma {
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
if (I == 32 && J == 4) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
if constexpr (I == 32 && J == 4) {
|
||||
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
|
||||
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
||||
#else
|
||||
return threadIdx.x;
|
||||
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
@@ -257,7 +276,7 @@ namespace ggml_cuda_mma {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr ((I == 8 || I == 32) && J == 8) {
|
||||
if constexpr (I == 32 && J == 4) {
|
||||
return l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -307,11 +326,11 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return (l * 8) | (threadIdx.x / 4);
|
||||
return (l * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l % 2) * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -320,13 +339,13 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
return (l * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l / 2) * 4) | (threadIdx.x % 4);
|
||||
return ((l / 2) * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return ((l & 2) * 2) | (threadIdx.x % 4);
|
||||
return ((l & 2) * 2) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -336,14 +355,15 @@ namespace ggml_cuda_mma {
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, nv_bfloat162> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
@@ -367,9 +387,6 @@ namespace ggml_cuda_mma {
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 16 && J == 4) return true;
|
||||
@@ -381,9 +398,9 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return (l * 8) | (threadIdx.x / 4);
|
||||
return (l * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l % 2) * 8) + (threadIdx.x / 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -392,11 +409,11 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
return (l * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l / 2) * 4) | (threadIdx.x % 4);
|
||||
return ((l / 2) * 4) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -405,6 +422,73 @@ namespace ggml_cuda_mma {
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||
static constexpr int ne = I * J / (WARP_SIZE/4);
|
||||
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 4) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
|
||||
static constexpr int ne = I * J / (WARP_SIZE/4);
|
||||
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 4) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return ((l / 2) * 4) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return ((threadIdx.x / 16) * 2) + (l % 2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
@@ -422,9 +506,26 @@ namespace ggml_cuda_mma {
|
||||
|
||||
return ret;
|
||||
}
|
||||
#else // Volta
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
|
||||
ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
||||
ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
|
||||
|
||||
template <int I, int J, typename T>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
// On Volta FP16 and FP32 tiles have a different memory layout,
|
||||
// for the conversion threads with an offset of 2 need to exchange half their values:
|
||||
ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
|
||||
0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
|
||||
template <int I, int J, typename T, data_layout dl>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
||||
#pragma unroll
|
||||
@@ -511,18 +612,6 @@ namespace ggml_cuda_mma {
|
||||
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
NO_DEVICE_CODE;
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if 1
|
||||
// TODO: more generic handling
|
||||
@@ -533,9 +622,31 @@ namespace ggml_cuda_mma {
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // 1
|
||||
#else
|
||||
tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
|
||||
load_ldmatrix(t16[0], xs0 + 0*stride, stride);
|
||||
load_ldmatrix(t16[1], xs0 + 16*stride, stride);
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < t.ne; l0 += 2) {
|
||||
ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
|
||||
#else
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
@@ -860,14 +971,14 @@ namespace ggml_cuda_mma {
|
||||
template <typename T1, typename T2, int J, int K>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
|
||||
tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
|
||||
tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
|
||||
tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
|
||||
const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
@@ -880,20 +991,30 @@ namespace ggml_cuda_mma {
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
|
||||
#else
|
||||
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
|
||||
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
|
||||
@@ -37,23 +37,19 @@ static __global__ void mul_mat_f(
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
@@ -248,6 +244,9 @@ static __global__ void mul_mat_f(
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif //VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
@@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids(
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
@@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids(
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
|
||||
@@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
|
||||
if (ppls->data.find(name) == ppls->data.end()) {
|
||||
if (ppls->data.find(name) == ppls->data.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -146,6 +146,8 @@ struct ggml_metal_library {
|
||||
id<MTLDevice> device;
|
||||
|
||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||
|
||||
NSLock * lock;
|
||||
};
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
@@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
|
||||
ggml_metal_pipelines_free(lib->pipelines);
|
||||
|
||||
[lib->lock release];
|
||||
|
||||
free(lib);
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
return ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
[lib->lock lock];
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
// note: the pipelines are cached in the library per device, so they are shared across all metal contexts
|
||||
ggml_critical_section_start();
|
||||
[lib->lock lock];
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
if (res) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
}
|
||||
if (!mtl_function) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
if (error) {
|
||||
@@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
(int) res->obj.threadExecutionWidth);
|
||||
|
||||
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
|
||||
|
||||
@@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
ggml_metal_pipelines_add(lib->pipelines, name, res);
|
||||
}
|
||||
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -39,8 +39,23 @@ add_dependencies(ggml-webgpu generate_shaders)
|
||||
if(EMSCRIPTEN)
|
||||
set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
|
||||
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
if(NOT EMDAWNWEBGPU_DIR)
|
||||
# default built-in port
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu")
|
||||
target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu")
|
||||
else()
|
||||
# custom port
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
endif()
|
||||
|
||||
if (GGML_WEBGPU_JSPI)
|
||||
target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions")
|
||||
target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions")
|
||||
else()
|
||||
target_compile_options(ggml-webgpu PRIVATE "-fexceptions")
|
||||
target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions")
|
||||
endif()
|
||||
else()
|
||||
find_package(Dawn REQUIRED)
|
||||
set(DawnWebGPU_TARGET dawn::webgpu_dawn)
|
||||
@@ -48,6 +63,9 @@ endif()
|
||||
|
||||
if (GGML_WEBGPU_DEBUG)
|
||||
target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
|
||||
if(EMSCRIPTEN)
|
||||
target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_WEBGPU_CPU_PROFILE)
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include <atomic>
|
||||
@@ -261,9 +265,12 @@ struct webgpu_context_struct {
|
||||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
|
||||
uint32_t subgroup_size;
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t subgroup_size;
|
||||
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
|
||||
#endif
|
||||
|
||||
// Separate this out from limits since on some Metal systems, the limit returned by
|
||||
// querying the limits is higher than the actual allowed maximum.
|
||||
@@ -449,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct
|
||||
// If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
|
||||
// inflight_max may be 0, meaning that we must wait on all futures.
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
uint inflight_threads = ctx->inflight_threads;
|
||||
uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
|
||||
uint32_t inflight_threads = ctx->inflight_threads;
|
||||
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
|
||||
while (futures.size() >= inflight_max && futures.size() > 0) {
|
||||
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
|
||||
futures.erase(futures.begin());
|
||||
@@ -986,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
|
||||
uint32_t wg_m;
|
||||
uint32_t wg_n;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
// The total number of subgroups/workgroups needed per matrix.
|
||||
uint32_t wg_m_sg_tile =
|
||||
@@ -995,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
|
||||
wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile;
|
||||
} else {
|
||||
#endif
|
||||
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s;
|
||||
wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
}
|
||||
#endif
|
||||
|
||||
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
}
|
||||
}
|
||||
@@ -1419,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
commands.push_back(*cmd);
|
||||
}
|
||||
// compute the batch size based on the number of inflight threads
|
||||
uint inflight_threads = ctx->inflight_threads;
|
||||
uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
|
||||
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
|
||||
uint32_t inflight_threads = ctx->inflight_threads;
|
||||
uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
|
||||
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
|
||||
if (commands.size() >= batch_size) {
|
||||
futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
|
||||
// Process events and check for completed submissions
|
||||
@@ -1758,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
|
||||
|
||||
std::string proc_mul_mat_f32_f32;
|
||||
std::string proc_mul_mat_f32_f32_vec;
|
||||
std::string proc_mul_mat_f16_f32;
|
||||
std::string proc_mul_mat_f16_f32_vec;
|
||||
std::string proc_mul_mat_f16_f16;
|
||||
std::string proc_mul_mat_f16_f16_vec;
|
||||
std::string proc_mul_mat_q4_0_f32;
|
||||
std::string proc_mul_mat_q4_0_f32_vec;
|
||||
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_constants;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->supports_subgroup_matrix) {
|
||||
std::map<std::string, std::string> sg_matrix_repls;
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
|
||||
@@ -1770,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
|
||||
|
||||
std::string proc_mul_mat_subgroup_matrix_f32_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
|
||||
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
proc_mul_mat_f32_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
|
||||
proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
|
||||
proc_mul_mat_f16_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f16 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
|
||||
proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
|
||||
proc_mul_mat_f16_f16_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
|
||||
proc_mul_mat_q4_0_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
|
||||
proc_mul_mat_q4_0_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f32_f32_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f16_f32_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f16_f16_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_q4_0_f32_vec");
|
||||
} else {
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
|
||||
mul_mat_reg_tile_constants[0].key = "TILE_K";
|
||||
mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K;
|
||||
mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M";
|
||||
mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N";
|
||||
mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
#endif
|
||||
mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
|
||||
mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
|
||||
mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
|
||||
|
||||
std::map<std::string, std::string> reg_repls;
|
||||
reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
|
||||
reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
|
||||
|
||||
// Process each reg-tile shader with tile replacements.
|
||||
// Keep the processed strings in-scope so .c_str() remains valid.
|
||||
std::string proc_mul_mat_reg_tile_f32_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f32_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f16 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f16_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_q4_0_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_q4_0_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(),
|
||||
"mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(),
|
||||
"mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(),
|
||||
"mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(),
|
||||
"mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(),
|
||||
"mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants);
|
||||
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
|
||||
proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
|
||||
proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
|
||||
proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
|
||||
proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
|
||||
proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
|
||||
proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
|
||||
proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
|
||||
#ifndef __EMSCRIPTEN__
|
||||
}
|
||||
#endif
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
|
||||
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
|
||||
mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
|
||||
@@ -2384,13 +2364,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
|
||||
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
||||
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
|
||||
const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
|
||||
wgpu::DawnTogglesDescriptor adapterTogglesDesc;
|
||||
adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
|
||||
adapterTogglesDesc.enabledToggleCount = 2;
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
options.nextInChain = &adapterTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
@@ -2406,11 +2390,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
ctx->adapter.GetLimits(&ctx->limits);
|
||||
ctx->max_wg_size_x = 288; // default value
|
||||
|
||||
wgpu::AdapterInfo info{};
|
||||
wgpu::AdapterInfo info{};
|
||||
#ifndef __EMSCRIPTEN__
|
||||
wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
|
||||
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
info.nextInChain = &subgroup_matrix_configs;
|
||||
}
|
||||
#endif
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
||||
wgpu::SupportedFeatures features;
|
||||
@@ -2418,6 +2404,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Only support square f16 matrices of size 8 or 16 for now
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
@@ -2433,36 +2420,27 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
}
|
||||
}
|
||||
|
||||
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
#endif
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
required_features.push_back(wgpu::FeatureName::Subgroups);
|
||||
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
||||
#endif
|
||||
|
||||
// Enable Dawn-specific toggles to increase native performance
|
||||
// TODO: Don't enable for WASM builds, they won't have an effect anyways
|
||||
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||
// only for native performance?
|
||||
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
||||
"disable_polyfills_on_integer_div_and_mod" };
|
||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||
deviceTogglesDesc.enabledToggleCount = 4;
|
||||
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||
deviceTogglesDesc.disabledToggleCount = 1;
|
||||
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &ctx->limits;
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
@@ -2480,7 +2458,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
std::string(message).c_str());
|
||||
});
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Enable Dawn-specific toggles to increase native performance
|
||||
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||
// only for native performance?
|
||||
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
||||
"disable_polyfills_on_integer_div_and_mod" };
|
||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||
deviceTogglesDesc.enabledToggleCount = 4;
|
||||
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||
deviceTogglesDesc.disabledToggleCount = 1;
|
||||
|
||||
dev_desc.nextInChain = &deviceTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
@@ -2578,18 +2572,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 1;
|
||||
|
||||
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
|
||||
|
||||
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
|
||||
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
|
||||
instanceTogglesDesc.enabledToggleCount = 1;
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
instance_descriptor.requiredFeatures = instance_features.data();
|
||||
instance_descriptor.requiredFeatureCount = instance_features.size();
|
||||
instance_descriptor.nextInChain = &instanceTogglesDesc;
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
|
||||
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
|
||||
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
|
||||
instanceTogglesDesc.enabledToggleCount = 1;
|
||||
instance_descriptor.nextInChain = &instanceTogglesDesc;
|
||||
#endif
|
||||
|
||||
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->instance == nullptr) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
GGML_ASSERT(webgpu_ctx->instance != nullptr);
|
||||
|
||||
static ggml_backend_reg reg = {
|
||||
|
||||
@@ -1169,7 +1169,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
|
||||
struct gguf_writer_base {
|
||||
size_t written_bytes {0u};
|
||||
|
||||
~gguf_writer_base(void) {}
|
||||
~gguf_writer_base(void) = default;
|
||||
|
||||
// we bet on devirtualization
|
||||
virtual void write(int8_t val) = 0;
|
||||
|
||||
@@ -31,6 +31,14 @@ except ImportError:
|
||||
else:
|
||||
_mistral_common_installed = True
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
get_one_valid_tokenizer_file,
|
||||
)
|
||||
except ImportError:
|
||||
# We still want the conversion to work with older mistral-common versions.
|
||||
get_one_valid_tokenizer_file = None
|
||||
|
||||
|
||||
import gguf
|
||||
|
||||
@@ -673,24 +681,30 @@ class MistralVocab(Vocab):
|
||||
|
||||
# Find the tokenizer files
|
||||
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
|
||||
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
|
||||
|
||||
if len(valid_tokenizer_files) == 0:
|
||||
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
|
||||
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
|
||||
if len(valid_tokenizer_files) > 1:
|
||||
if "tekken.json" in valid_tokenizer_files:
|
||||
tokenizer_file = "tekken.json"
|
||||
else:
|
||||
tokenizer_file = sorted(valid_tokenizer_files)[-1]
|
||||
logger.warning(
|
||||
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
|
||||
)
|
||||
if get_one_valid_tokenizer_file is not None:
|
||||
tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
|
||||
else:
|
||||
tokenizer_file = valid_tokenizer_files[0]
|
||||
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
|
||||
|
||||
if len(valid_tokenizer_files) == 0:
|
||||
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
|
||||
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
|
||||
if len(valid_tokenizer_files) > 1:
|
||||
if "tekken.json" in valid_tokenizer_files:
|
||||
tokenizer_file = "tekken.json"
|
||||
else:
|
||||
tokenizer_file = sorted(valid_tokenizer_files)[-1]
|
||||
logger.warning(
|
||||
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
|
||||
)
|
||||
else:
|
||||
tokenizer_file = valid_tokenizer_files[0]
|
||||
|
||||
tokenizer_file_path = base_path / tokenizer_file
|
||||
|
||||
self.tokenizer = MistralTokenizer.from_file(
|
||||
base_path / tokenizer_file
|
||||
tokenizer_file_path
|
||||
).instruct_tokenizer.tokenizer
|
||||
self.tokenizer_type = (
|
||||
MistralTokenizerType.tekken
|
||||
@@ -698,7 +712,7 @@ class MistralVocab(Vocab):
|
||||
else MistralTokenizerType.spm
|
||||
)
|
||||
self.vocab_size = self.tokenizer.n_words
|
||||
self.fname_tokenizer = base_path / tokenizer_file
|
||||
self.fname_tokenizer = tokenizer_file_path
|
||||
self._name = (
|
||||
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
|
||||
)
|
||||
|
||||
110
scripts/serve-static.js
Normal file
110
scripts/serve-static.js
Normal file
@@ -0,0 +1,110 @@
|
||||
const http = require('http');
|
||||
const fs = require('fs').promises;
|
||||
const path = require('path');
|
||||
|
||||
// This file is used for testing wasm build from emscripten
|
||||
// Example build command:
|
||||
// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF
|
||||
// cmake --build build-wasm --target test-backend-ops -j
|
||||
|
||||
const PORT = 8080;
|
||||
const STATIC_DIR = path.join(__dirname, '../build-wasm/bin');
|
||||
console.log(`Serving static files from: ${STATIC_DIR}`);
|
||||
|
||||
const mimeTypes = {
|
||||
'.html': 'text/html',
|
||||
'.js': 'text/javascript',
|
||||
'.css': 'text/css',
|
||||
'.png': 'image/png',
|
||||
'.jpg': 'image/jpeg',
|
||||
'.gif': 'image/gif',
|
||||
'.svg': 'image/svg+xml',
|
||||
'.json': 'application/json',
|
||||
'.woff': 'font/woff',
|
||||
'.woff2': 'font/woff2',
|
||||
};
|
||||
|
||||
async function generateDirListing(dirPath, reqUrl) {
|
||||
const files = await fs.readdir(dirPath);
|
||||
let html = `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Directory Listing</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; padding: 20px; }
|
||||
ul { list-style: none; padding: 0; }
|
||||
li { margin: 5px 0; }
|
||||
a { text-decoration: none; color: #0066cc; }
|
||||
a:hover { text-decoration: underline; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Directory: ${reqUrl}</h1>
|
||||
<ul>
|
||||
`;
|
||||
|
||||
if (reqUrl !== '/') {
|
||||
html += `<li><a href="../">../ (Parent Directory)</a></li>`;
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
const filePath = path.join(dirPath, file);
|
||||
const stats = await fs.stat(filePath);
|
||||
const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : '');
|
||||
html += `<li><a href="${link}">${file}${stats.isDirectory() ? '/' : ''}</a></li>`;
|
||||
}
|
||||
|
||||
html += `
|
||||
</ul>
|
||||
</body>
|
||||
</html>
|
||||
`;
|
||||
return html;
|
||||
}
|
||||
|
||||
const server = http.createServer(async (req, res) => {
|
||||
try {
|
||||
// Set COOP and COEP headers
|
||||
res.setHeader('Cross-Origin-Opener-Policy', 'same-origin');
|
||||
res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp');
|
||||
res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate');
|
||||
res.setHeader('Pragma', 'no-cache');
|
||||
res.setHeader('Expires', '0');
|
||||
|
||||
const filePath = path.join(STATIC_DIR, decodeURIComponent(req.url));
|
||||
const stats = await fs.stat(filePath);
|
||||
|
||||
if (stats.isDirectory()) {
|
||||
const indexPath = path.join(filePath, 'index.html');
|
||||
try {
|
||||
const indexData = await fs.readFile(indexPath);
|
||||
res.writeHeader(200, { 'Content-Type': 'text/html' });
|
||||
res.end(indexData);
|
||||
} catch {
|
||||
// No index.html, generate directory listing
|
||||
const dirListing = await generateDirListing(filePath, req.url);
|
||||
res.writeHeader(200, { 'Content-Type': 'text/html' });
|
||||
res.end(dirListing);
|
||||
}
|
||||
} else {
|
||||
const ext = path.extname(filePath).toLowerCase();
|
||||
const contentType = mimeTypes[ext] || 'application/octet-stream';
|
||||
const data = await fs.readFile(filePath);
|
||||
res.writeHeader(200, { 'Content-Type': contentType });
|
||||
res.end(data);
|
||||
}
|
||||
} catch (err) {
|
||||
if (err.code === 'ENOENT') {
|
||||
res.writeHeader(404, { 'Content-Type': 'text/plain' });
|
||||
res.end('404 Not Found');
|
||||
} else {
|
||||
res.writeHeader(500, { 'Content-Type': 'text/plain' });
|
||||
res.end('500 Internal Server Error');
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.listen(PORT, () => {
|
||||
console.log(`Server running at http://localhost:${PORT}/`);
|
||||
});
|
||||
@@ -37,7 +37,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
|
||||
template <typename T>
|
||||
struct no_init {
|
||||
T value;
|
||||
no_init() { /* do nothing */ }
|
||||
no_init() = default;
|
||||
};
|
||||
|
||||
struct time_meas {
|
||||
|
||||
@@ -423,8 +423,8 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s
|
||||
}
|
||||
|
||||
struct llama_model::impl {
|
||||
impl() {}
|
||||
~impl() {}
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
uint64_t n_elements = 0;
|
||||
|
||||
@@ -461,7 +461,7 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi
|
||||
pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
|
||||
}
|
||||
|
||||
llama_model::~llama_model() {}
|
||||
llama_model::~llama_model() = default;
|
||||
|
||||
void llama_model::load_stats(llama_model_loader & ml) {
|
||||
pimpl->n_elements = ml.n_elements;
|
||||
|
||||
@@ -3253,8 +3253,7 @@ void llama_vocab::impl::print_info() const {
|
||||
llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
|
||||
}
|
||||
|
||||
llama_vocab::~llama_vocab() {
|
||||
}
|
||||
llama_vocab::~llama_vocab() = default;
|
||||
|
||||
void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
pimpl->load(ml, kv);
|
||||
|
||||
1
tests/.gitignore
vendored
1
tests/.gitignore
vendored
@@ -3,3 +3,4 @@
|
||||
*.o
|
||||
ggml-common.h
|
||||
**/*.swp
|
||||
!peg-parser
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
llama_add_compile_flags()
|
||||
|
||||
function(llama_build source)
|
||||
set(TEST_SOURCES ${source} ${ARGN})
|
||||
|
||||
if (DEFINED LLAMA_TEST_NAME)
|
||||
set(TEST_TARGET ${LLAMA_TEST_NAME})
|
||||
else()
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source})
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCES})
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
endfunction()
|
||||
@@ -83,6 +85,8 @@ function(llama_build_and_test source)
|
||||
set(multiValueArgs ARGS)
|
||||
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp)
|
||||
|
||||
if (NOT DEFINED LLAMA_TEST_LABEL)
|
||||
set(LLAMA_TEST_LABEL "main")
|
||||
endif()
|
||||
@@ -95,7 +99,7 @@ function(llama_build_and_test source)
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source} get-model.cpp)
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCES})
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
|
||||
@@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
endif()
|
||||
|
||||
llama_build_and_test(test-chat-parser.cpp)
|
||||
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(
|
||||
test-peg-parser.cpp
|
||||
peg-parser/simple-tokenize.cpp
|
||||
peg-parser/test-basic.cpp
|
||||
peg-parser/test-gbnf-generation.cpp
|
||||
peg-parser/test-json-parser.cpp
|
||||
peg-parser/test-json-serialization.cpp
|
||||
peg-parser/test-unicode.cpp
|
||||
peg-parser/testing.h
|
||||
peg-parser/tests.h
|
||||
)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
|
||||
37
tests/peg-parser/simple-tokenize.cpp
Normal file
37
tests/peg-parser/simple-tokenize.cpp
Normal file
@@ -0,0 +1,37 @@
|
||||
#include "simple-tokenize.h"
|
||||
|
||||
std::vector<std::string> simple_tokenize(const std::string & input) {
|
||||
std::vector<std::string> result;
|
||||
std::string current;
|
||||
|
||||
for (size_t i = 0; i < input.size(); i++) {
|
||||
switch (input[i]) {
|
||||
case ' ':
|
||||
case '\n':
|
||||
case '\t':
|
||||
case '{':
|
||||
case '}':
|
||||
case ',':
|
||||
case '[':
|
||||
case '"':
|
||||
case ']':
|
||||
case '.':
|
||||
case '<':
|
||||
case '>':
|
||||
case '=':
|
||||
case '/':
|
||||
if (!current.empty()) {
|
||||
result.push_back(current);
|
||||
current.clear();
|
||||
}
|
||||
default:;
|
||||
}
|
||||
current += input[i];
|
||||
}
|
||||
|
||||
if (!current.empty()) {
|
||||
result.push_back(current);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
6
tests/peg-parser/simple-tokenize.h
Normal file
6
tests/peg-parser/simple-tokenize.h
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
std::vector<std::string> simple_tokenize(const std::string &);
|
||||
454
tests/peg-parser/test-basic.cpp
Normal file
454
tests/peg-parser/test-basic.cpp
Normal file
@@ -0,0 +1,454 @@
|
||||
#include "tests.h"
|
||||
|
||||
void test_basic(testing & t) {
|
||||
t.test("chars", [](testing & t) {
|
||||
// Test common escape sequences - newline
|
||||
t.test("escape_sequence_newline", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("\n");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_newline", true, result.success());
|
||||
});
|
||||
|
||||
// Test common escape sequences - tab
|
||||
t.test("escape_sequence_tab", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("\t");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_tab", true, result.success());
|
||||
});
|
||||
|
||||
// Test common escape sequences - backslash
|
||||
t.test("escape_sequence_backslash", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("\\");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_backslash", true, result.success());
|
||||
});
|
||||
|
||||
// Test common escape sequences - space (should ())
|
||||
t.test("escape_sequence_space_fail", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context(" ");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_space_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Test escaped dash - 'a' should succeed
|
||||
t.test("escaped_dash_a", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("a");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_a", true, result.success());
|
||||
});
|
||||
|
||||
// Test escaped dash - '-' should succeed (literal dash)
|
||||
t.test("escaped_dash_literal", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("-");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_literal", true, result.success());
|
||||
});
|
||||
|
||||
// Test escaped dash - 'z' should succeed
|
||||
t.test("escaped_dash_z", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("z");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_z", true, result.success());
|
||||
});
|
||||
|
||||
// Test escaped dash - 'b' should NOT match (since \- is literal dash, not range)
|
||||
t.test("escaped_dash_b_fail", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("b");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_b_fail", true, result.fail());
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
t.test("optional", [](testing & t) {
|
||||
// Full match with optional part present
|
||||
t.test("optional_present", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto ctx = common_peg_parse_context("hello world");
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("optional_present", true, result.success());
|
||||
t.assert_equal("optional_present_end", 11u, result.end);
|
||||
});
|
||||
|
||||
// Full match with optional part absent
|
||||
t.test("optional_absent", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto ctx = common_peg_parse_context("hello", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("optional_absent", true, result.success());
|
||||
t.assert_equal("optional_absent_end", 5u, result.end);
|
||||
});
|
||||
|
||||
// Partial match - waiting for more input to determine if optional matches
|
||||
t.test("partial_match_need_more", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto ctx = common_peg_parse_context("hello ", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("partial_match_need_more", true, result.need_more_input());
|
||||
});
|
||||
});
|
||||
|
||||
t.test("partial parsing", [](testing & t) {
|
||||
// Literals - Basic Success
|
||||
t.test("literal_success", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("hello");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("literal_success", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes - Basic Lowercase Success
|
||||
t.test("char_class_lowercase_success", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("a");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_lowercase_success", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes - Uppercase Fail
|
||||
t.test("char_class_uppercase_fail", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("A");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_uppercase_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Char Classes with Dash - Lowercase Success
|
||||
t.test("char_class_with_dash_lowercase", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("f");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_with_dash_lowercase", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes with Dash - Literal Dash Success
|
||||
t.test("char_class_with_dash_literal_dash", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("-");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_with_dash_literal_dash", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes with Dash - Uppercase Fail
|
||||
t.test("char_class_with_dash_uppercase_fail", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("A");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_with_dash_uppercase_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Sequences - Partial Match 1
|
||||
t.test("sequence_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("<thi", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Sequences - Partial Match 2
|
||||
t.test("sequence_partial_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("begin") + p.literal("end"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("begin", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Sequences - Partial Match 3
|
||||
t.test("sequence_partial_match_3", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("<think></", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_partial_match_3", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Sequences - Full Match
|
||||
t.test("sequence_full_match", [&](testing & t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello") + p.literal("world"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("helloworld", false);
|
||||
auto result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("sequence_full_match", true, result.success());
|
||||
});
|
||||
|
||||
// Sequences - No Match
|
||||
t.test("sequence_no_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("<think>I am common_chat_combinator_parser", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_no_match", true, result.fail());
|
||||
});
|
||||
|
||||
// Choices - Partial Match 1
|
||||
t.test("choices_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("opt", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Choices - Partial Match 2
|
||||
t.test("choices_partial_match_2", [&](testing & t) {
|
||||
auto parser =
|
||||
build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("choice", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Choices - Full Match 1
|
||||
t.test("choices_full_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("first", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_full_match_1", true, result.success());
|
||||
});
|
||||
|
||||
// Choices - Full Match 2
|
||||
t.test("choices_full_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("beta", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_full_match_2", true, result.success());
|
||||
});
|
||||
|
||||
// Choices - No Match
|
||||
t.test("choices_no_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("best", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_no_match", true, result.fail());
|
||||
});
|
||||
|
||||
// Zero or More - Partial Match 1
|
||||
t.test("zero_or_more_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("a", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("zero_or_more_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Zero or More - Partial Match 2
|
||||
t.test("zero_or_more_partial_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("xyx", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("zero_or_more_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Zero or More - Full Match
|
||||
t.test("zero_or_more_full_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("test", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("zero_or_more_full_match", true, result.success());
|
||||
});
|
||||
|
||||
// One or More - Partial Match 1
|
||||
t.test("one_or_more_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("rep", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// One or More - Partial Match 2
|
||||
t.test("one_or_more_partial_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("aba", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// One or More - Full Match
|
||||
t.test("one_or_more_full_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("single", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_full_match", true, result.success());
|
||||
});
|
||||
|
||||
// One or More - No Match
|
||||
t.test("one_or_more_no_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("success", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_no_match", true, result.fail());
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
t.test("recursive rules", [](testing &t) {
|
||||
// Test simple number
|
||||
t.test("simple_number", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("1", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test simple list
|
||||
t.test("simple_list", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[1]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test nested list
|
||||
t.test("nested_list", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[[2]]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test deeply nested list
|
||||
t.test("deeply_nested_list", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[[[3]]]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test need_more_input match
|
||||
t.test("need_more_input_match", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[[", true);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test no match
|
||||
t.test("no_match", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[a]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_fail", true, result.fail());
|
||||
});
|
||||
});
|
||||
}
|
||||
250
tests/peg-parser/test-gbnf-generation.cpp
Normal file
250
tests/peg-parser/test-gbnf-generation.cpp
Normal file
@@ -0,0 +1,250 @@
|
||||
#include "tests.h"
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
#include <regex>
|
||||
|
||||
static std::string trim_leading_space(const std::string & s) {
|
||||
static const std::regex leading_ws_re = std::regex(R"((^|\n)\s+)");
|
||||
return std::regex_replace(s, leading_ws_re, "$1");
|
||||
}
|
||||
|
||||
static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) {
|
||||
t.assert_equal("gbnf are equal", trim_leading_space(expected), trim_leading_space(actual));
|
||||
}
|
||||
|
||||
void test_gbnf_generation(testing &t) {
|
||||
t.test("literal grammar generation", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("char class grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.chars("[a-z]", 1, 1);
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= [a-z]
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("sequence grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.literal(" ") + p.literal("world");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello" " " "world"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("choice grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("cat") | p.literal("dog");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "cat" | "dog"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("one_or_more grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.one_or_more(p.literal("a"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("zero_or_more grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.zero_or_more(p.literal("a"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"*
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("optional grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello" " world"?
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])*
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("complex expressions with parentheses", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.one_or_more(p.literal("a") | p.literal("b"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ("a" | "b")+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("rule references", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
auto digit = p.rule("digit", p.chars("[0-9]", 1, 1));
|
||||
return p.one_or_more(digit);
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
digit ::= [0-9]
|
||||
root ::= digit+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("escaping in literals", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello\nworld\n!");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello\nworld\n!"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("operator<< (whitespace insertion)", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") << p.literal("world");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello" space "world"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("emit only reachable rules", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("orphan", p.literal("orphan"));
|
||||
return p.literal("hello") + p.rule("child", p.literal(" world"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
child ::= " world"
|
||||
root ::= "hello" child
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("emit only trigger rules (and references)", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2"));
|
||||
p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true);
|
||||
p.rule("rule-3", p.literal("c") + p.ref("rule-4"));
|
||||
p.rule("rule-4", p.literal("d"), true);
|
||||
return rule1;
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= rule-1
|
||||
rule-1 ::= "a" rule-2
|
||||
rule-2 ::= "b" rule-3
|
||||
rule-3 ::= "c" rule-4
|
||||
rule-4 ::= "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
|
||||
auto gbnf_lazy = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder, true);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= rule-2 | rule-4
|
||||
rule-2 ::= "b" rule-3
|
||||
rule-3 ::= "c" rule-4
|
||||
rule-4 ::= "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf_lazy);
|
||||
});
|
||||
}
|
||||
109
tests/peg-parser/test-json-parser.cpp
Normal file
109
tests/peg-parser/test-json-parser.cpp
Normal file
@@ -0,0 +1,109 @@
|
||||
#include "tests.h"
|
||||
|
||||
void test_json_parser(testing &t) {
|
||||
// Test parsing a simple JSON object
|
||||
t.test("simple JSON object parsing", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"({"name": "test", "value": 42, "flag": true})";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing a JSON array with mixed types
|
||||
t.test("JSON array with mixed types", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"([1, "hello", true, null, 3.14])";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing nested JSON with objects and arrays
|
||||
t.test("nested JSON with objects and arrays", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input =
|
||||
R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete object
|
||||
t.test("need_more_input() parsing - incomplete object", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"({"name": "test", "value": )";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete array
|
||||
t.test("need_more_input() parsing - incomplete array", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"([1, 2, 3, )";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete nested structure
|
||||
t.test("need_more_input() parsing - incomplete nested structure", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"({"data": {"nested": )";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
t.test("object member", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.json_member("name", "\"" + p.chars("[a-z]") + "\"");
|
||||
});
|
||||
|
||||
t.test("success", [&](testing &t) {
|
||||
std::string input = R"("name": "bob")";
|
||||
common_peg_parse_context ctx(input, false);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
});
|
||||
|
||||
t.test("partial", [&](testing &t) {
|
||||
std::string input = R"("name": "bo)";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("need more input", result.need_more_input());
|
||||
});
|
||||
|
||||
t.test("failed", [&](testing &t) {
|
||||
std::string input = R"([])";
|
||||
common_peg_parse_context ctx(input, false);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("fail", result.fail());
|
||||
});
|
||||
});
|
||||
}
|
||||
28
tests/peg-parser/test-json-serialization.cpp
Normal file
28
tests/peg-parser/test-json-serialization.cpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#include "tests.h"
|
||||
|
||||
void test_json_serialization(testing &t) {
|
||||
auto original = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return "<tool_call>" + p.json() + "</tool_call>";
|
||||
});
|
||||
|
||||
auto json_serialized = original.to_json().dump();
|
||||
|
||||
t.test("compare before/after", [&](testing &t) {
|
||||
auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));
|
||||
|
||||
// Test complex JSON
|
||||
std::string input = R"({"name": "test", "values": [1, 2, 3], "nested": {"a": true}})";
|
||||
common_peg_parse_context ctx1(input);
|
||||
common_peg_parse_context ctx2(input);
|
||||
|
||||
auto result1 = original.parse(ctx1);
|
||||
auto result2 = deserialized.parse(ctx2);
|
||||
|
||||
t.assert_equal("both_succeed", result1.success(), result2.success());
|
||||
t.assert_equal("same_end_pos", result1.end, result2.end);
|
||||
});
|
||||
|
||||
t.bench("deserialize", [&]() {
|
||||
auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));
|
||||
}, 100);
|
||||
}
|
||||
449
tests/peg-parser/test-unicode.cpp
Normal file
449
tests/peg-parser/test-unicode.cpp
Normal file
@@ -0,0 +1,449 @@
|
||||
#include "tests.h"
|
||||
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <cctype>
|
||||
|
||||
static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) {
|
||||
t.assert_equal(common_peg_parse_result_type_name(expected), common_peg_parse_result_type_name(actual));
|
||||
}
|
||||
|
||||
static std::string hex_dump(const std::string& str) {
|
||||
std::ostringstream oss;
|
||||
for (unsigned char c : str) {
|
||||
if (std::isprint(c)) {
|
||||
oss << c;
|
||||
} else {
|
||||
oss << "\\x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c);
|
||||
}
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void test_unicode(testing &t) {
|
||||
struct test_case {
|
||||
std::string input;
|
||||
std::string expected_text;
|
||||
common_peg_parse_result_type expected_result;
|
||||
};
|
||||
|
||||
t.test("any", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Valid UTF-8 sequences
|
||||
{"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
{std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
{std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Incomplete UTF-8 sequences (partial bytes at end)
|
||||
{std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
{std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
{std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Invalid/malformed UTF-8 sequences
|
||||
{std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
{std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.one_or_more(p.any()), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("char classes", [](testing &t) {
|
||||
t.test("unicode range U+4E00-U+9FFF (CJK)", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Within range - CJK Unified Ideographs
|
||||
{std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
|
||||
{std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
|
||||
{std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D
|
||||
{std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF
|
||||
|
||||
// Outside range - should fail
|
||||
{"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII
|
||||
{std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range)
|
||||
{std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range)
|
||||
|
||||
// Incomplete sequences in range
|
||||
{std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00
|
||||
{std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("unicode range U+1F600-U+1F64F (emoticons)", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Within range - Emoticons (all 4-byte UTF-8)
|
||||
{std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
|
||||
{std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601
|
||||
{std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F
|
||||
|
||||
// Outside range
|
||||
{std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range)
|
||||
{std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range)
|
||||
{std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range)
|
||||
|
||||
// Incomplete sequences
|
||||
{std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji
|
||||
{std::string("\xF0\x9F"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("mixed unicode ranges", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Match CJK
|
||||
{std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
|
||||
{std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
|
||||
|
||||
// Match emoticons
|
||||
{std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
|
||||
|
||||
// Match ASCII digits
|
||||
{"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Don't match outside any range
|
||||
{"a", "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
{std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680
|
||||
|
||||
// Incomplete
|
||||
{std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
{std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
t.test("until parser", [](testing &t) {
|
||||
t.test("ASCII delimiter with Unicode content", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// CJK characters before delimiter
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD</tag>"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Emoji before delimiter
|
||||
{std::string("\xF0\x9F\x98\x80</tag>"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Mixed content
|
||||
{std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!</tag>"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("incomplete UTF-8 at end", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Incomplete emoji at end, no delimiter
|
||||
{std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete CJK at end, no delimiter
|
||||
{std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Complete content, no delimiter (should consume all valid UTF-8)
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("malformed UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Invalid UTF-8 bytes
|
||||
{std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Continuation byte without lead byte
|
||||
{std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Invalid continuation byte
|
||||
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
t.test("json_string parser", [](testing &t) {
|
||||
t.test("valid UTF-8 characters", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// ASCII only
|
||||
{"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// 2-byte UTF-8 (accented characters)
|
||||
{std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// 3-byte UTF-8 (CJK)
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// 4-byte UTF-8 (emoji)
|
||||
{std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Mixed content
|
||||
{std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.json_string_content(), p.literal("\"")});
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("incomplete UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Incomplete 2-byte sequence
|
||||
{std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete 3-byte sequence
|
||||
{std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete 4-byte sequence
|
||||
{std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete at very start
|
||||
{std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.json_string_content();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("malformed UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Invalid UTF-8 bytes
|
||||
{std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Continuation byte without lead byte
|
||||
{std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Invalid continuation byte
|
||||
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Overlong encoding (security issue)
|
||||
{std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.json_string_content();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("escape sequences with UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Unicode escape sequence
|
||||
{"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Mix of UTF-8 and escape sequences
|
||||
{std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Escaped quote in UTF-8 string
|
||||
{std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.json_string_content(), p.literal("\"")});
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
243
tests/peg-parser/testing.h
Normal file
243
tests/peg-parser/testing.h
Normal file
@@ -0,0 +1,243 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
|
||||
struct testing {
|
||||
std::ostream &out;
|
||||
std::vector<std::string> stack;
|
||||
std::regex filter;
|
||||
bool filter_tests = false;
|
||||
bool throw_exception = false;
|
||||
bool verbose = false;
|
||||
int tests = 0;
|
||||
int assertions = 0;
|
||||
int failures = 0;
|
||||
int unnamed = 0;
|
||||
int exceptions = 0;
|
||||
|
||||
static constexpr std::size_t status_column = 80;
|
||||
|
||||
explicit testing(std::ostream &os = std::cout) : out(os) {}
|
||||
|
||||
std::string indent() const {
|
||||
if (stack.empty()) {
|
||||
return "";
|
||||
}
|
||||
return std::string((stack.size() - 1) * 2, ' ');
|
||||
}
|
||||
|
||||
std::string full_name() const {
|
||||
return string_join(stack, ".");
|
||||
}
|
||||
|
||||
void log(const std::string & msg) {
|
||||
if (verbose) {
|
||||
out << indent() << " " << msg << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void set_filter(const std::string & re) {
|
||||
filter = std::regex(re);
|
||||
filter_tests = true;
|
||||
}
|
||||
|
||||
bool should_run() const {
|
||||
if (filter_tests) {
|
||||
if (!std::regex_match(full_name(), filter)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void run_with_exceptions(F &&f, const char *ctx) {
|
||||
try {
|
||||
f();
|
||||
} catch (const std::exception &e) {
|
||||
++failures;
|
||||
++exceptions;
|
||||
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
|
||||
if (throw_exception) {
|
||||
throw;
|
||||
}
|
||||
} catch (...) {
|
||||
++failures;
|
||||
++exceptions;
|
||||
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
|
||||
if (throw_exception) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
|
||||
std::string line = indent() + label;
|
||||
|
||||
std::string details;
|
||||
if (new_assertions > 0) {
|
||||
if (new_failures == 0) {
|
||||
details = std::to_string(new_assertions) + " assertion(s)";
|
||||
} else {
|
||||
details = std::to_string(new_failures) + " of " +
|
||||
std::to_string(new_assertions) + " assertion(s) failed";
|
||||
}
|
||||
}
|
||||
if (!extra.empty()) {
|
||||
if (!details.empty()) {
|
||||
details += ", ";
|
||||
}
|
||||
details += extra;
|
||||
}
|
||||
|
||||
if (!details.empty()) {
|
||||
line += " (" + details + ")";
|
||||
}
|
||||
|
||||
std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
|
||||
|
||||
if (line.size() + 1 < status_column) {
|
||||
line.append(status_column - line.size(), ' ');
|
||||
} else {
|
||||
line.push_back(' ');
|
||||
}
|
||||
|
||||
out << line << status << "\n";
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void test(const std::string &name, F f) {
|
||||
stack.push_back(name);
|
||||
if (!should_run()) {
|
||||
stack.pop_back();
|
||||
return;
|
||||
}
|
||||
|
||||
++tests;
|
||||
out << indent() << name << "\n";
|
||||
|
||||
int before_failures = failures;
|
||||
int before_assertions = assertions;
|
||||
|
||||
run_with_exceptions([&] { f(*this); }, "test");
|
||||
|
||||
int new_failures = failures - before_failures;
|
||||
int new_assertions = assertions - before_assertions;
|
||||
|
||||
print_result(name, new_failures, new_assertions);
|
||||
|
||||
stack.pop_back();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void test(F f) {
|
||||
test("test #" + std::to_string(++unnamed), f);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void bench(const std::string &name, F f, int iterations = 100) {
|
||||
stack.push_back(name);
|
||||
if (!should_run()) {
|
||||
stack.pop_back();
|
||||
return;
|
||||
}
|
||||
|
||||
++tests;
|
||||
out << indent() << "[bench] " << name << "\n";
|
||||
|
||||
int before_failures = failures;
|
||||
int before_assertions = assertions;
|
||||
|
||||
using clock = std::chrono::high_resolution_clock;
|
||||
|
||||
std::chrono::microseconds duration(0);
|
||||
|
||||
run_with_exceptions([&] {
|
||||
for (auto i = 0; i < iterations; i++) {
|
||||
auto start = clock::now();
|
||||
f();
|
||||
duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
|
||||
}
|
||||
}, "bench");
|
||||
|
||||
auto avg_elapsed = duration.count() / iterations;
|
||||
auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
|
||||
auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
|
||||
|
||||
int new_failures = failures - before_failures;
|
||||
int new_assertions = assertions - before_assertions;
|
||||
|
||||
std::string extra =
|
||||
"n=" + std::to_string(iterations) +
|
||||
" avg=" + std::to_string(avg_elapsed) + "us" +
|
||||
" rate=" + std::to_string(int(rate)) + "/s";
|
||||
|
||||
print_result("[bench] " + name, new_failures, new_assertions, extra);
|
||||
|
||||
stack.pop_back();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void bench(F f, int iterations = 100) {
|
||||
bench("bench #" + std::to_string(++unnamed), f, iterations);
|
||||
}
|
||||
|
||||
// Assertions
|
||||
bool assert_true(bool cond) {
|
||||
return assert_true("", cond);
|
||||
}
|
||||
|
||||
bool assert_true(const std::string &msg, bool cond) {
|
||||
++assertions;
|
||||
if (!cond) {
|
||||
++failures;
|
||||
out << indent() << "ASSERT TRUE FAILED";
|
||||
if (!msg.empty()) {
|
||||
out << " : " << msg;
|
||||
}
|
||||
out << "\n";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
bool assert_equal(const A &expected, const B &actual) {
|
||||
return assert_equal("", expected, actual);
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
|
||||
++assertions;
|
||||
if (!(actual == expected)) {
|
||||
++failures;
|
||||
out << indent() << "ASSERT EQUAL FAILED";
|
||||
if (!msg.empty()) {
|
||||
out << " : " << msg;
|
||||
}
|
||||
out << "\n";
|
||||
|
||||
out << indent() << " expected: " << expected << "\n";
|
||||
out << indent() << " actual : " << actual << "\n";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Print summary and return an exit code
|
||||
int summary() const {
|
||||
out << "\n";
|
||||
out << "tests : " << tests << "\n";
|
||||
out << "assertions : " << assertions << "\n";
|
||||
out << "failures : " << failures << "\n";
|
||||
out << "exceptions : " << exceptions << "\n";
|
||||
return failures == 0 ? 0 : 1;
|
||||
}
|
||||
};
|
||||
24
tests/peg-parser/tests.h
Normal file
24
tests/peg-parser/tests.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
// Common includes for all test files
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "testing.h"
|
||||
#include "peg-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "simple-tokenize.h"
|
||||
|
||||
struct bench_tool_call {
|
||||
std::string id;
|
||||
std::string name;
|
||||
nlohmann::ordered_json args;
|
||||
};
|
||||
|
||||
// Test function declarations
|
||||
void test_basic(testing &t);
|
||||
void test_json_parser(testing &t);
|
||||
void test_gbnf_generation(testing &t);
|
||||
void test_unicode(testing &t);
|
||||
void test_json_serialization(testing &t);
|
||||
@@ -41,12 +41,18 @@
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# define N_THREADS 1
|
||||
#else
|
||||
# define N_THREADS std::thread::hardware_concurrency()
|
||||
#endif
|
||||
|
||||
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
||||
size_t nels = ggml_nelements(tensor);
|
||||
std::vector<float> data(nels);
|
||||
{
|
||||
// parallel initialization
|
||||
static const size_t n_threads = std::thread::hardware_concurrency();
|
||||
static const size_t n_threads = N_THREADS;
|
||||
// static RNG initialization (revisit if n_threads stops being constant)
|
||||
static std::vector<std::default_random_engine> generators = []() {
|
||||
std::random_device rd;
|
||||
@@ -65,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*nels/n_threads;
|
||||
size_t end = (i+1)*nels/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
if (n_threads == 1) {
|
||||
init_thread(0, 0, nels);
|
||||
} else {
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*nels/n_threads;
|
||||
size_t end = (i+1)*nels/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
};
|
||||
|
||||
const size_t min_blocks_per_thread = 1;
|
||||
const size_t n_threads = std::min<size_t>(std::thread::hardware_concurrency()/2,
|
||||
std::max<size_t>(1, n_blocks / min_blocks_per_thread));
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*n_blocks/n_threads;
|
||||
size_t end = (i+1)*n_blocks/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
const size_t n_quant_threads = std::min<size_t>(std::max<size_t>(N_THREADS/2, 1),
|
||||
std::max<size_t>(1, n_blocks / min_blocks_per_thread));
|
||||
|
||||
if (n_quant_threads == 1) {
|
||||
// single-threaded quantization: do all blocks in the current thread
|
||||
quantize_thread(0, n_blocks);
|
||||
} else {
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_quant_threads);
|
||||
for (size_t i = 0; i < n_quant_threads; i++) {
|
||||
size_t start = i*n_blocks/n_quant_threads;
|
||||
size_t end = (i+1)*n_blocks/n_quant_threads;
|
||||
tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
|
||||
@@ -8363,7 +8379,7 @@ int main(int argc, char ** argv) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
// TODO: better value for n_threads
|
||||
ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
|
||||
ggml_backend_set_n_threads_fn(backend, N_THREADS);
|
||||
}
|
||||
|
||||
size_t free, total; // NOLINT
|
||||
|
||||
768
tests/test-chat-peg-parser.cpp
Normal file
768
tests/test-chat-peg-parser.cpp
Normal file
@@ -0,0 +1,768 @@
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "chat-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "peg-parser.h"
|
||||
#include "peg-parser/testing.h"
|
||||
#include "peg-parser/simple-tokenize.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static json create_tools();
|
||||
static void test_example_native(testing & t);
|
||||
static void test_example_qwen3_coder(testing & t);
|
||||
static void test_command7_parser_compare(testing & t);
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
testing t(std::cout);
|
||||
if (argc >= 2) {
|
||||
t.set_filter(argv[1]);
|
||||
}
|
||||
|
||||
const char * verbose = getenv("LLAMA_TEST_VERBOSE");
|
||||
if (verbose) {
|
||||
t.verbose = std::string(verbose) == "1";
|
||||
}
|
||||
|
||||
t.test("native", test_example_native);
|
||||
t.test("qwen3 coder", test_example_qwen3_coder);
|
||||
t.test("comparison", test_command7_parser_compare);
|
||||
|
||||
return t.summary();
|
||||
}
|
||||
|
||||
static json create_tools() {
|
||||
json tools = json::array();
|
||||
|
||||
json tool_weather = {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "get_current_weather"},
|
||||
{"description", "Get the current weather in a given location"},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"location", {
|
||||
{"type", "string"},
|
||||
{"description", "The city and state, e.g. San Francisco, CA"}
|
||||
}},
|
||||
{"unit", {
|
||||
{"type", "string"},
|
||||
{"enum", {"celsius", "fahrenheit"}},
|
||||
{"description", "The temperature unit to use. Infer this from the users location."}
|
||||
}}
|
||||
}},
|
||||
{"required", {"location", "unit"}},
|
||||
}},
|
||||
}}
|
||||
};
|
||||
tools.push_back(tool_weather);
|
||||
|
||||
json tool_forecast = {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "get_forecast"},
|
||||
{"description", "Get the weather forecast for a given location"},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"location", {
|
||||
{"type", "string"},
|
||||
{"description", "The city and state, e.g. San Francisco, CA"}
|
||||
}},
|
||||
{"unit", {
|
||||
{"type", "string"},
|
||||
{"enum", {"celsius", "fahrenheit"}},
|
||||
{"description", "The temperature unit to use. Infer this from the users location."}
|
||||
}},
|
||||
{"days", {
|
||||
{"type", "integer"},
|
||||
{"description", "Number of days to forecast (1-10)"},
|
||||
{"minimum", 1},
|
||||
{"maximum", 10}
|
||||
}}
|
||||
}},
|
||||
{"required", {"location", "unit"}},
|
||||
}},
|
||||
}}
|
||||
};
|
||||
tools.push_back(tool_forecast);
|
||||
|
||||
json tool_search = {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "search_knowledge_base"},
|
||||
{"description", "Search the internal technical documentation knowledge base."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"query", {
|
||||
{"type", "string"},
|
||||
{"description", "The search query string."}
|
||||
}},
|
||||
{"max_results", {
|
||||
{"type", "integer"},
|
||||
{"description", "The maximum number of results to return."},
|
||||
{"default", 5}
|
||||
}},
|
||||
{"category", {
|
||||
{"type", "string"},
|
||||
{"enum", {"api", "troubleshooting", "billing", "general"}},
|
||||
{"description", "Filter search by specific category."}
|
||||
}}
|
||||
}},
|
||||
{"required", {"query", "category"}},
|
||||
{"additionalProperties", false}
|
||||
}},
|
||||
{"strict", true}
|
||||
}}
|
||||
};
|
||||
tools.push_back(tool_search);
|
||||
|
||||
return tools;
|
||||
}
|
||||
|
||||
struct tool_argument {
|
||||
std::string name;
|
||||
std::string type;
|
||||
bool is_required;
|
||||
json schema;
|
||||
};
|
||||
|
||||
struct tool_definition {
|
||||
std::string name;
|
||||
std::vector<tool_argument> arguments;
|
||||
json schema;
|
||||
};
|
||||
|
||||
// Test fictitious model output that emits arguments as JSON.
|
||||
static void test_example_native(testing & t) {
|
||||
struct test_case {
|
||||
// Parameters
|
||||
std::string name;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice;
|
||||
common_reasoning_format reasoning_format;
|
||||
json json_schema;
|
||||
bool parallel_tool_calls;
|
||||
bool thinking_forced_open;
|
||||
std::string input;
|
||||
|
||||
// Expect
|
||||
std::string expect_reasoning;
|
||||
std::string expect_content;
|
||||
std::vector<common_chat_tool_call> expect_tool_calls;
|
||||
};
|
||||
|
||||
auto build_parser = [](const test_case & tc) {
|
||||
return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
|
||||
auto reasoning = p.eps();
|
||||
if (tc.thinking_forced_open) {
|
||||
// If thinking is forced open, expect a closing tag
|
||||
reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
|
||||
} else {
|
||||
// Otherwise, optionally accept thinking wrapped in tags
|
||||
reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
|
||||
}
|
||||
|
||||
// tool calling parser
|
||||
if (tc.tools.is_array() && !tc.tools.empty()) {
|
||||
auto tools = p.choice();
|
||||
for (const auto & tool : tc.tools) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\"");
|
||||
auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
|
||||
|
||||
tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}");
|
||||
};
|
||||
|
||||
auto parallel_calls = p.eps();
|
||||
if (tc.parallel_tool_calls) {
|
||||
parallel_calls = p.zero_or_more("," << tools);
|
||||
}
|
||||
|
||||
auto tool_call = p.trigger_rule("tool-call",
|
||||
p.sequence({
|
||||
p.literal("<tool_call>["),
|
||||
tools,
|
||||
parallel_calls,
|
||||
p.literal("]</tool_call>")
|
||||
})
|
||||
);
|
||||
|
||||
return p.sequence({
|
||||
(reasoning_in_content ? p.eps() : reasoning),
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.optional(p.space() + tool_call),
|
||||
p.space(),
|
||||
p.end()
|
||||
});
|
||||
}
|
||||
|
||||
// response_format parser
|
||||
if (tc.json_schema.is_object() && !tc.json_schema.empty()) {
|
||||
return p.sequence({
|
||||
(reasoning_in_content ? p.eps() : reasoning),
|
||||
p.content(p.schema(p.json(), "response-output", tc.json_schema)),
|
||||
p.space(),
|
||||
p.end()
|
||||
});
|
||||
}
|
||||
|
||||
// Content-only parser
|
||||
return p.sequence({
|
||||
(reasoning_in_content ? p.eps() : reasoning),
|
||||
p.content(p.rest()),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
std::vector<test_case> test_cases = std::vector<test_case>{
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = false",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .input = */ (
|
||||
"<think>The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "The user said hello, I must say hello back",
|
||||
/* .expect_content = */ "Hello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = false and no reasoning",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .input = */ (
|
||||
"Hello"
|
||||
),
|
||||
/* .expect_reasoning = */ "",
|
||||
/* .expect_content = */ "Hello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = false and reasoning_format = none",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"<think>The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "",
|
||||
/* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = true",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "The user said hello, I must say hello back",
|
||||
/* .expect_content = */ "Hello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = true and reasoning_format = none",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "",
|
||||
/* .expect_content = */ "The user said hello, I must say hello back</think>\nHello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "tools with tool_choice = auto and no parallel_tool_calls",
|
||||
/* .tools = */ create_tools(),
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"I must get the weather in New York</think>\n"
|
||||
"<tool_call>["
|
||||
R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
|
||||
"]</tool_call>"
|
||||
),
|
||||
/* .expect_reasoning = */ "I must get the weather in New York",
|
||||
/* .expect_content = */ "",
|
||||
/* .expect_tool_calls = */ {{
|
||||
/* .name = */ "get_current_weather",
|
||||
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
|
||||
/* .id = */ "",
|
||||
}},
|
||||
},
|
||||
{
|
||||
/* .name = */ "tools with tool_choice = auto and parallel_tool_calls",
|
||||
/* .tools = */ create_tools(),
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ true,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you."
|
||||
"<tool_call>["
|
||||
R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
|
||||
", "
|
||||
R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})"
|
||||
", "
|
||||
R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})"
|
||||
", "
|
||||
R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})"
|
||||
"]</tool_call>"
|
||||
),
|
||||
/* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
|
||||
/* .expect_content = */ "Let me search that for you.",
|
||||
/* .expect_tool_calls = */ {{
|
||||
/* .name = */ "get_current_weather",
|
||||
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
|
||||
/* .id = */ "",
|
||||
}, {
|
||||
/* .name = */ "get_current_weather",
|
||||
/* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})",
|
||||
/* .id = */ "",
|
||||
}, {
|
||||
/* .name = */ "get_forecast",
|
||||
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})",
|
||||
/* .id = */ "",
|
||||
}, {
|
||||
/* .name = */ "get_forecast",
|
||||
/* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})",
|
||||
/* .id = */ "",
|
||||
}},
|
||||
},
|
||||
{
|
||||
/* .name = */ "response_format with thinking_forced_open = true",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"invoice_number", {{"type", "string"}}},
|
||||
{"amount", {{"type", "number"}}},
|
||||
{"due_date", {{"type", "string"}}}
|
||||
}},
|
||||
{"required", {"invoice_number", "amount", "due_date"}}
|
||||
},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"I must produce the invoice in the requested format</think>\n"
|
||||
R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"
|
||||
),
|
||||
/* .expect_reasoning = */ "I must produce the invoice in the requested format",
|
||||
/* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
};
|
||||
|
||||
for (const auto & tc : test_cases) {
|
||||
t.test(tc.name, [&](testing & t) {
|
||||
auto parser = build_parser(tc);
|
||||
auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
auto grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
for (auto const & def : tc.tools) {
|
||||
auto function = def.at("function");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
};
|
||||
parser.build_grammar(builder, lazy);
|
||||
});
|
||||
|
||||
t.log("Grammar:");
|
||||
for (auto const & line : string_split(grammar, "\n")) {
|
||||
t.log(line);
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_true("success", result.success());
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_native_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
t.assert_equal("content equal", tc.expect_content, msg.content);
|
||||
t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content);
|
||||
t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size());
|
||||
for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) {
|
||||
t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name);
|
||||
t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void test_example_qwen3_coder(testing & t) {
|
||||
auto tools = create_tools();
|
||||
auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
|
||||
auto content = p.rule("content", p.content(p.until("<tool_call>")));
|
||||
|
||||
std::vector<common_peg_parser> tool_parsers;
|
||||
for (auto const & def : tools) {
|
||||
auto function = def.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
auto properties = parameters.at("properties");
|
||||
|
||||
std::set<std::string> required_properties;
|
||||
if (function.contains("required")) {
|
||||
function.at("required").get_to(required_properties);
|
||||
}
|
||||
|
||||
std::vector<common_peg_parser> arg_parsers;
|
||||
for (const auto & [param_name, param_schema] : properties.items()) {
|
||||
bool is_required = required_properties.find(param_name) != required_properties.end();
|
||||
auto type = param_schema.value("type", "object");
|
||||
|
||||
auto arg = p.tool_arg(p.sequence({
|
||||
p.tool_arg_open("<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">"),
|
||||
(type == "string" ?
|
||||
p.tool_arg_string_value(
|
||||
p.schema(
|
||||
p.until_one_of({
|
||||
"</parameter>\n<parameter=",
|
||||
"</parameter>\n</function>"
|
||||
}),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema,
|
||||
true
|
||||
)
|
||||
) : p.tool_arg_json_value(
|
||||
p.schema(
|
||||
p.json(),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema
|
||||
)
|
||||
)
|
||||
),
|
||||
p.tool_arg_close(
|
||||
"</parameter>\n" +
|
||||
p.peek(p.literal("<parameter=") | p.literal("</function>"))
|
||||
)
|
||||
}));
|
||||
|
||||
arg_parsers.push_back(is_required ?
|
||||
p.rule("tool-" + name + "-arg-" + param_name, arg) :
|
||||
p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg)));
|
||||
}
|
||||
|
||||
tool_parsers.push_back(p.rule("tool-" + name,
|
||||
p.tool_open("<function=" + p.tool_name(p.literal(name)) + ">")
|
||||
<< p.sequence(arg_parsers)
|
||||
<< p.tool_close(p.literal("</function>"))
|
||||
));
|
||||
};
|
||||
|
||||
auto tool_call = p.trigger_rule("tool-call",
|
||||
"<tool_call>"
|
||||
<< p.choice(tool_parsers)
|
||||
<< "</tool_call>"
|
||||
);
|
||||
|
||||
return content + p.zero_or_more(p.space() + tool_call) + p.end();
|
||||
});
|
||||
|
||||
auto grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
for (auto const & def : tools) {
|
||||
auto function = def.at("function");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
};
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
t.log("Grammar:");
|
||||
for (auto const & line : string_split(grammar, "\n")) {
|
||||
t.log(line);
|
||||
}
|
||||
|
||||
t.test("incremental parsing", [&](testing &t) {
|
||||
std::string input =
|
||||
"Let me search the knowledge base for cat pictures."
|
||||
"<tool_call>\n"
|
||||
"<function=search_knowledge_base>\n"
|
||||
"<parameter=query>cat pictures</parameter>\n"
|
||||
"<parameter=category>general</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
|
||||
std::vector<std::string> tokens = simple_tokenize(input);
|
||||
|
||||
common_chat_msg prev;
|
||||
for (auto it = tokens.begin(); it != tokens.end(); it++) {
|
||||
std::string in = std::accumulate(tokens.begin(), it + 1, std::string());
|
||||
|
||||
common_peg_parse_context ctx(in, it + 1 < tokens.end());
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
if (!t.assert_equal("not fail", false, result.fail())) {
|
||||
t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
|
||||
}
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_constructed_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
//t.log("Input: " + input);
|
||||
t.log("===========================================");
|
||||
t.log("Iteration " + std::to_string(in.size()));
|
||||
t.log("Reasoning: " + msg.reasoning_content);
|
||||
t.log("Content : " + msg.content);
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
t.log("Tool name: " + tc.name);
|
||||
t.log("Tool args: " + tc.arguments);
|
||||
}
|
||||
|
||||
try {
|
||||
// This shouldn't emit any runtime errors
|
||||
auto diffs = common_chat_msg_diff::compute_diffs(prev, msg);
|
||||
} catch(const std::exception & e) {
|
||||
t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
|
||||
t.assert_true(std::string("failed with ") + e.what(), false);
|
||||
}
|
||||
|
||||
prev = msg;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void test_command7_parser_compare(testing & t) {
|
||||
auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) {
|
||||
auto thinking = p.reasoning_block(
|
||||
"<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>");
|
||||
|
||||
auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>";
|
||||
|
||||
auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\"")));
|
||||
auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\"")));
|
||||
auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json()));
|
||||
|
||||
auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args);
|
||||
auto tool_call = p.rule("tool-call", p.tool(
|
||||
p.tool_open(p.literal("{"))
|
||||
<< tool_call_fields
|
||||
<< p.zero_or_more( p.literal(",") << tool_call_fields)
|
||||
<< p.tool_close(p.literal("}"))
|
||||
));
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
"<|START_ACTION|>"
|
||||
<< ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]")
|
||||
<< "<|END_ACTION|>");
|
||||
|
||||
return p.optional(thinking) << (tool_calls | response) + p.end();
|
||||
});
|
||||
|
||||
auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) {
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
auto result = p.parse(ctx);
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_native_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
if (print_results) {
|
||||
std::cout << "== Parsed (new) ==\n";
|
||||
std::cout << "=== Reasoning ===\n";
|
||||
std::cout << msg.reasoning_content << "\n";
|
||||
std::cout << "\n\n=== Content ===\n";
|
||||
std::cout << msg.content << "\n";
|
||||
std::cout << "\n\n=== Tool Calls ===\n";
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
std::cout << "id: " << tc.id << "\n";
|
||||
std::cout << "name: " << tc.name << "\n";
|
||||
std::cout << "args: " << tc.arguments << "\n";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
|
||||
// Original common_chat_combinator_parser taken from chat.cpp
|
||||
common_chat_msg_parser builder(
|
||||
input,
|
||||
/* .is_partial = */ need_more_input,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GENERIC,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
}
|
||||
);
|
||||
|
||||
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
|
||||
|
||||
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
|
||||
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
|
||||
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
|
||||
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
|
||||
|
||||
if (auto res = builder.try_find_regex(start_action_regex)) {
|
||||
// If we didn't extract thoughts, prelude includes them.
|
||||
auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } });
|
||||
for (const auto & tool_call : tool_calls.value) {
|
||||
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
|
||||
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
|
||||
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
|
||||
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
}
|
||||
if (tool_calls.is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
builder.consume_regex(end_action_regex);
|
||||
} else if (auto res = builder.try_find_regex(start_response_regex)) {
|
||||
if (!builder.try_find_regex(end_response_regex)) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
throw common_chat_msg_partial_exception(end_response_regex.str());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
if (print_results) {
|
||||
std::cout << "== Parsed (legacy) ==\n";
|
||||
std::cout << "=== Reasoning ===\n";
|
||||
std::cout << builder.result().reasoning_content << "\n";
|
||||
std::cout << "\n\n=== Content ===\n";
|
||||
std::cout << builder.result().content << "\n";
|
||||
std::cout << "\n\n=== Tool Calls ===\n";
|
||||
for (const auto & tc : builder.result().tool_calls) {
|
||||
std::cout << "id: " << tc.id << "\n";
|
||||
std::cout << "name: " << tc.name << "\n";
|
||||
std::cout << "args: " << tc.arguments << "\n";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a "
|
||||
"budget of $4000 for a two-week stay, we need to:\n\n"
|
||||
"1. Identify key historical sites and modern attractions in Japan.\n"
|
||||
"2. Find affordable accommodation options that provide a balance between comfort and cost.\n"
|
||||
"3. Determine the best modes of transportation for getting around Japan.\n"
|
||||
"4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without "
|
||||
"overspending.\n"
|
||||
"5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees "
|
||||
"to attractions.";
|
||||
|
||||
std::vector<std::tuple<std::string, std::string, nlohmann::json>> tool_calls = {{
|
||||
"call_0",
|
||||
"plan_trip",
|
||||
nlohmann::json::parse(R"({
|
||||
"destination": "Japan",
|
||||
"duration": 14,
|
||||
"budget": 4000,
|
||||
"interests": ["historical sites", "modern attractions"],
|
||||
"accommodation_preferences": "affordable",
|
||||
"transportation_preferences": "efficient",
|
||||
"meal_preferences": "local cuisine"
|
||||
})")
|
||||
}};
|
||||
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
// Build tokens
|
||||
if (!reasoning.empty()) {
|
||||
auto tokenized = simple_tokenize(reasoning);
|
||||
tokens.emplace_back("<|START_THINKING|>");
|
||||
tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
|
||||
tokens.emplace_back("<|END_THINKING|>");
|
||||
}
|
||||
|
||||
if (!tool_calls.empty()) {
|
||||
tokens.emplace_back("<|START_ACTION|>");
|
||||
|
||||
auto json = nlohmann::json::array();
|
||||
for (const auto & tc : tool_calls) {
|
||||
auto tc_json = nlohmann::json::object();
|
||||
tc_json["tool_call_id"] = std::get<0>(tc);
|
||||
tc_json["tool_name"] = std::get<1>(tc);
|
||||
tc_json["parameters"] = std::get<2>(tc);
|
||||
json.push_back(tc_json);
|
||||
}
|
||||
|
||||
auto tokenized = simple_tokenize(json.dump(-1, ' ', true));
|
||||
tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
|
||||
|
||||
tokens.emplace_back("<|END_ACTION|>");
|
||||
}
|
||||
|
||||
std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string());
|
||||
|
||||
// Run tests
|
||||
t.test("legacy_parse", [&](testing & /* t */) {
|
||||
test_legacy(input, false, false);
|
||||
});
|
||||
|
||||
t.test("current_parse", [&](testing & /* t */) {
|
||||
test_current(parser, input, false, false);
|
||||
});
|
||||
|
||||
// Run benchmarks
|
||||
t.bench("legacy_parse_benchmark complete", [&]() {
|
||||
test_legacy(input, false, false);
|
||||
});
|
||||
|
||||
t.bench("legacy_parse_benchmark incremental", [&]() {
|
||||
std::string in;
|
||||
for (auto i = 0u; i < tokens.size(); i++) {
|
||||
in += tokens[i];
|
||||
|
||||
try {
|
||||
test_legacy(in, i + 1 < tokens.size(), false);
|
||||
} catch (common_chat_msg_partial_exception & /* e */) {
|
||||
// Do nothing, this is expected
|
||||
}
|
||||
}
|
||||
}, 20);
|
||||
|
||||
t.bench("current_parse_benchmark complete", [&]() {
|
||||
test_current(parser, input, false, false);
|
||||
}, 100);
|
||||
|
||||
t.bench("current_parse_benchmark incremental", [&]() {
|
||||
std::string in;
|
||||
for (auto i = 0u; i < tokens.size(); i++) {
|
||||
in += tokens[i];
|
||||
test_current(parser, in, i + 1 < tokens.size(), false);
|
||||
}
|
||||
}, 20);
|
||||
}
|
||||
25
tests/test-peg-parser.cpp
Normal file
25
tests/test-peg-parser.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "peg-parser/tests.h"
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
testing t(std::cout);
|
||||
if (argc >= 2) {
|
||||
t.set_filter(argv[1]);
|
||||
}
|
||||
|
||||
const char * verbose = getenv("LLAMA_TEST_VERBOSE");
|
||||
if (verbose) {
|
||||
t.verbose = std::string(verbose) == "1";
|
||||
}
|
||||
|
||||
t.test("basic", test_basic);
|
||||
t.test("unicode", test_unicode);
|
||||
t.test("json", test_json_parser);
|
||||
t.test("gbnf", test_gbnf_generation);
|
||||
t.test("serialization", test_json_serialization);
|
||||
|
||||
return t.summary();
|
||||
}
|
||||
@@ -2,11 +2,6 @@ set(TARGET llama-server)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
if (MINGW)
|
||||
# fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
if (NOT LLAMA_HTTPLIB)
|
||||
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
|
||||
endif()
|
||||
|
||||
Binary file not shown.
@@ -791,7 +791,7 @@ static void handle_media(
|
||||
SRV_INF("downloading image from '%s'\n", url.c_str());
|
||||
auto res = common_remote_get_content(url, params);
|
||||
if (200 <= res.first && res.first < 300) {
|
||||
SRV_INF("downloaded %ld bytes\n", res.second.size());
|
||||
SRV_INF("downloaded %zu bytes\n", res.second.size());
|
||||
raw_buffer data;
|
||||
data.insert(data.end(), res.second.begin(), res.second.end());
|
||||
out_files.push_back(data);
|
||||
@@ -1045,6 +1045,9 @@ json oaicompat_chat_params_parse(
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
if (!chat_params.parser.empty()) {
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
@@ -889,6 +890,28 @@ struct pipe_t {
|
||||
}
|
||||
};
|
||||
|
||||
static std::string to_lower_copy(const std::string & value) {
|
||||
std::string lowered(value.size(), '\0');
|
||||
std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
|
||||
return lowered;
|
||||
}
|
||||
|
||||
static bool should_strip_proxy_header(const std::string & header_name) {
|
||||
// Headers that get duplicated when router forwards child responses
|
||||
if (header_name == "server" ||
|
||||
header_name == "transfer-encoding" ||
|
||||
header_name == "keep-alive") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Router injects CORS, child also sends them: duplicate
|
||||
if (header_name.rfind("access-control-", 0) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
server_http_proxy::server_http_proxy(
|
||||
const std::string & method,
|
||||
const std::string & host,
|
||||
@@ -925,6 +948,14 @@ server_http_proxy::server_http_proxy(
|
||||
msg_t msg;
|
||||
msg.status = response.status;
|
||||
for (const auto & [key, value] : response.headers) {
|
||||
const auto lowered = to_lower_copy(key);
|
||||
if (should_strip_proxy_header(lowered)) {
|
||||
continue;
|
||||
}
|
||||
if (lowered == "content-type") {
|
||||
msg.content_type = value;
|
||||
continue;
|
||||
}
|
||||
msg.headers[key] = value;
|
||||
}
|
||||
return pipe->write(std::move(msg)); // send headers first
|
||||
@@ -932,7 +963,7 @@ server_http_proxy::server_http_proxy(
|
||||
httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
|
||||
// send data chunks
|
||||
// returns false if pipe is closed / broken (signal to stop receiving)
|
||||
return pipe->write({{}, 0, std::string(data, data_length)});
|
||||
return pipe->write({{}, 0, std::string(data, data_length), ""});
|
||||
};
|
||||
|
||||
// prepare the request to destination server
|
||||
@@ -955,8 +986,8 @@ server_http_proxy::server_http_proxy(
|
||||
if (result.error() != httplib::Error::Success) {
|
||||
auto err_str = httplib::to_string(result.error());
|
||||
SRV_ERR("http client error: %s\n", err_str.c_str());
|
||||
pipe->write({{}, 500, ""}); // header
|
||||
pipe->write({{}, 0, "proxy error: " + err_str}); // body
|
||||
pipe->write({{}, 500, "", ""}); // header
|
||||
pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
|
||||
}
|
||||
pipe->close_write(); // signal EOF to reader
|
||||
SRV_DBG("%s", "client request thread ended\n");
|
||||
@@ -964,12 +995,17 @@ server_http_proxy::server_http_proxy(
|
||||
this->thread.detach();
|
||||
|
||||
// wait for the first chunk (headers)
|
||||
msg_t header;
|
||||
if (pipe->read(header, should_stop)) {
|
||||
SRV_DBG("%s", "received response headers\n");
|
||||
this->status = header.status;
|
||||
this->headers = header.headers;
|
||||
} else {
|
||||
SRV_DBG("%s", "no response headers received (request cancelled?)\n");
|
||||
{
|
||||
msg_t header;
|
||||
if (pipe->read(header, should_stop)) {
|
||||
SRV_DBG("%s", "received response headers\n");
|
||||
this->status = header.status;
|
||||
this->headers = std::move(header.headers);
|
||||
if (!header.content_type.empty()) {
|
||||
this->content_type = std::move(header.content_type);
|
||||
}
|
||||
} else {
|
||||
SRV_DBG("%s", "no response headers received (request cancelled?)\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,5 +170,6 @@ private:
|
||||
std::map<std::string, std::string> headers;
|
||||
int status = 0;
|
||||
std::string data;
|
||||
std::string content_type;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
||||
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
||||
if (data.contains("chat_parser")) {
|
||||
params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@@ -65,6 +65,7 @@ def test_server_slots():
|
||||
|
||||
def test_load_split_model():
|
||||
global server
|
||||
server.offline = False
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
|
||||
server.model_alias = "tinyllama-split"
|
||||
|
||||
@@ -17,7 +17,6 @@ def create_server():
|
||||
]
|
||||
)
|
||||
def test_router_chat_completion_stream(model: str, success: bool):
|
||||
# TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
|
||||
global server
|
||||
server.start()
|
||||
content = ""
|
||||
@@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool):
|
||||
else:
|
||||
assert ex is not None
|
||||
assert content == ""
|
||||
|
||||
|
||||
def _get_model_status(model_id: str) -> str:
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
for item in res.body.get("data", []):
|
||||
if item.get("id") == model_id or item.get("model") == model_id:
|
||||
return item["status"]["value"]
|
||||
raise AssertionError(f"Model {model_id} not found in /models response")
|
||||
|
||||
|
||||
def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
|
||||
deadline = time.time() + timeout
|
||||
last_status = None
|
||||
while time.time() < deadline:
|
||||
last_status = _get_model_status(model_id)
|
||||
if last_status in desired:
|
||||
return last_status
|
||||
time.sleep(1)
|
||||
raise AssertionError(
|
||||
f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
|
||||
)
|
||||
|
||||
|
||||
def _load_model_and_wait(
|
||||
model_id: str, timeout: int = 60, headers: dict | None = None
|
||||
) -> None:
|
||||
load_res = server.make_request(
|
||||
"POST", "/models/load", data={"model": model_id}, headers=headers
|
||||
)
|
||||
assert load_res.status_code == 200
|
||||
assert isinstance(load_res.body, dict)
|
||||
assert load_res.body.get("success") is True
|
||||
_wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
|
||||
|
||||
|
||||
def test_router_unload_model():
|
||||
global server
|
||||
server.start()
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
|
||||
assert unload_res.status_code == 200
|
||||
assert unload_res.body.get("success") is True
|
||||
_wait_for_model_status(model_id, {"unloaded"})
|
||||
|
||||
|
||||
def test_router_models_max_evicts_lru():
|
||||
global server
|
||||
server.models_max = 2
|
||||
server.start()
|
||||
|
||||
candidate_models = [
|
||||
"ggml-org/tinygemma3-GGUF:Q8_0",
|
||||
"ggml-org/test-model-stories260K",
|
||||
"ggml-org/test-model-stories260K-infill",
|
||||
]
|
||||
|
||||
# Load only the first 2 models to fill the cache
|
||||
first, second, third = candidate_models[:3]
|
||||
|
||||
_load_model_and_wait(first, timeout=120)
|
||||
_load_model_and_wait(second, timeout=120)
|
||||
|
||||
# Verify both models are loaded
|
||||
assert _get_model_status(first) == "loaded"
|
||||
assert _get_model_status(second) == "loaded"
|
||||
|
||||
# Load the third model - this should trigger LRU eviction of the first model
|
||||
_load_model_and_wait(third, timeout=120)
|
||||
|
||||
# Verify eviction: third is loaded, first was evicted
|
||||
assert _get_model_status(third) == "loaded"
|
||||
assert _get_model_status(first) == "unloaded"
|
||||
|
||||
|
||||
def test_router_no_models_autoload():
|
||||
global server
|
||||
server.no_models_autoload = True
|
||||
server.start()
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
success_res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert success_res.status_code == 200
|
||||
assert "error" not in success_res.body
|
||||
|
||||
|
||||
def test_router_api_key_required():
|
||||
global server
|
||||
server.api_key = "sk-router-secret"
|
||||
server.start()
|
||||
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
auth_headers = {"Authorization": f"Bearer {server.api_key}"}
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 401
|
||||
assert res.body.get("error", {}).get("type") == "authentication_error"
|
||||
|
||||
_load_model_and_wait(model_id, headers=auth_headers)
|
||||
|
||||
authed = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers=auth_headers,
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert authed.status_code == 200
|
||||
assert "error" not in authed.body
|
||||
|
||||
@@ -7,6 +7,7 @@ import subprocess
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
import sys
|
||||
import requests
|
||||
import time
|
||||
@@ -83,6 +84,9 @@ class ServerProcess:
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
models_dir: str | None = None
|
||||
models_max: int | None = None
|
||||
no_models_autoload: bool | None = None
|
||||
lora_files: List[str] | None = None
|
||||
enable_ctx_shift: int | None = False
|
||||
draft_min: int | None = None
|
||||
@@ -143,6 +147,10 @@ class ServerProcess:
|
||||
server_args.extend(["--hf-repo", self.model_hf_repo])
|
||||
if self.model_hf_file:
|
||||
server_args.extend(["--hf-file", self.model_hf_file])
|
||||
if self.models_dir:
|
||||
server_args.extend(["--models-dir", self.models_dir])
|
||||
if self.models_max is not None:
|
||||
server_args.extend(["--models-max", self.models_max])
|
||||
if self.n_batch:
|
||||
server_args.extend(["--batch-size", self.n_batch])
|
||||
if self.n_ubatch:
|
||||
@@ -204,6 +212,8 @@ class ServerProcess:
|
||||
server_args.extend(["--draft-min", self.draft_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.no_models_autoload:
|
||||
server_args.append("--no-models-autoload")
|
||||
if self.jinja:
|
||||
server_args.append("--jinja")
|
||||
else:
|
||||
@@ -295,7 +305,13 @@ class ServerProcess:
|
||||
result = ServerResponse()
|
||||
result.headers = dict(response.headers)
|
||||
result.status_code = response.status_code
|
||||
result.body = response.json() if parse_body else None
|
||||
if parse_body:
|
||||
try:
|
||||
result.body = response.json()
|
||||
except JSONDecodeError:
|
||||
result.body = response.text
|
||||
else:
|
||||
result.body = None
|
||||
print("Response from server", json.dumps(result.body, indent=2))
|
||||
return result
|
||||
|
||||
@@ -434,8 +450,9 @@ class ServerPreset:
|
||||
@staticmethod
|
||||
def tinyllama2() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K.gguf"
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/test-model-stories260K"
|
||||
server.model_hf_file = None
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.n_ctx = 512
|
||||
server.n_batch = 32
|
||||
@@ -479,8 +496,8 @@ class ServerPreset:
|
||||
def tinyllama_infill() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
|
||||
server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
|
||||
server.model_hf_file = None
|
||||
server.model_alias = "tinyllama-infill"
|
||||
server.n_ctx = 2048
|
||||
server.n_batch = 1024
|
||||
@@ -537,6 +554,7 @@ class ServerPreset:
|
||||
@staticmethod
|
||||
def router() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
# router server has no models
|
||||
server.model_file = None
|
||||
server.model_alias = None
|
||||
|
||||
@@ -15,7 +15,7 @@ sequenceDiagram
|
||||
Stores->>DB: load conversations
|
||||
Stores->>API: GET /props
|
||||
API-->>Stores: {role: "router"}
|
||||
Stores->>API: GET /models
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: models[] with status (loaded/available)
|
||||
loop each loaded model
|
||||
Stores->>API: GET /props?model=X
|
||||
@@ -28,7 +28,7 @@ sequenceDiagram
|
||||
alt model not loaded
|
||||
Stores->>API: POST /models/load
|
||||
loop poll status
|
||||
Stores->>API: GET /models
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: check if loaded
|
||||
end
|
||||
Stores->>API: GET /props?model=X
|
||||
|
||||
@@ -56,7 +56,7 @@ sequenceDiagram
|
||||
UI->>modelsStore: fetchRouterModels()
|
||||
activate modelsStore
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /models
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: ApiRouterModelsListResponse
|
||||
Note right of API: {data: [{id, status, path, in_cache}]}
|
||||
modelsStore->>modelsStore: routerModels = $state(data)
|
||||
@@ -132,7 +132,7 @@ sequenceDiagram
|
||||
loop poll every 500ms (max 60 attempts)
|
||||
modelsStore->>modelsStore: fetchRouterModels()
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /models
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: models[]
|
||||
modelsStore->>modelsStore: getModelStatus(modelId)
|
||||
alt status === LOADED
|
||||
@@ -165,7 +165,7 @@ sequenceDiagram
|
||||
modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED)
|
||||
loop poll until unloaded
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /models
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
end
|
||||
|
||||
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)
|
||||
|
||||
@@ -64,7 +64,10 @@
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state('');
|
||||
let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500);
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? 2500 : n;
|
||||
});
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { ChatMessage } from '$lib/components/app';
|
||||
import { DatabaseService } from '$lib/services/database';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { getMessageSiblings } from '$lib/utils';
|
||||
@@ -19,7 +18,7 @@
|
||||
const conversation = activeConversation();
|
||||
|
||||
if (conversation) {
|
||||
DatabaseService.getConversationMessages(conversation.id).then((messages) => {
|
||||
conversationsStore.getConversationMessages(conversation.id).then((messages) => {
|
||||
allConversationMessages = messages;
|
||||
});
|
||||
} else {
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
import { Textarea } from '$lib/components/ui/textarea';
|
||||
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
||||
import { settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { ParameterSyncService } from '$lib/services/parameter-sync';
|
||||
import { ChatSettingsParameterSourceIndicator } from '$lib/components/app';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
@@ -22,7 +21,7 @@
|
||||
|
||||
// Helper function to get parameter source info for syncable parameters
|
||||
function getParameterSourceInfo(key: string) {
|
||||
if (!ParameterSyncService.canSyncParameter(key)) {
|
||||
if (!settingsStore.canSyncParameter(key)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
import { Download, Upload } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { DialogConversationSelection } from '$lib/components/app';
|
||||
import { DatabaseService } from '$lib/services/database';
|
||||
import { createMessageCountMap } from '$lib/utils';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
|
||||
let exportedConversations = $state<DatabaseConversation[]>([]);
|
||||
let importedConversations = $state<DatabaseConversation[]>([]);
|
||||
@@ -21,15 +20,15 @@
|
||||
|
||||
async function handleExportClick() {
|
||||
try {
|
||||
const allConversations = await DatabaseService.getAllConversations();
|
||||
const allConversations = conversations();
|
||||
if (allConversations.length === 0) {
|
||||
alert('No conversations to export');
|
||||
return;
|
||||
}
|
||||
|
||||
const conversationsWithMessages = await Promise.all(
|
||||
allConversations.map(async (conv) => {
|
||||
const messages = await DatabaseService.getConversationMessages(conv.id);
|
||||
allConversations.map(async (conv: DatabaseConversation) => {
|
||||
const messages = await conversationsStore.getConversationMessages(conv.id);
|
||||
return { conv, messages };
|
||||
})
|
||||
);
|
||||
@@ -47,7 +46,7 @@
|
||||
try {
|
||||
const allData: ExportedConversations = await Promise.all(
|
||||
selectedConversations.map(async (conv) => {
|
||||
const messages = await DatabaseService.getConversationMessages(conv.id);
|
||||
const messages = await conversationsStore.getConversationMessages(conv.id);
|
||||
return { conv: $state.snapshot(conv), messages: $state.snapshot(messages) };
|
||||
})
|
||||
);
|
||||
@@ -135,9 +134,7 @@
|
||||
.snapshot(fullImportData)
|
||||
.filter((item) => selectedIds.has(item.conv.id));
|
||||
|
||||
await DatabaseService.importConversations(selectedData);
|
||||
|
||||
await conversationsStore.loadConversations();
|
||||
await conversationsStore.importConversationsData(selectedData);
|
||||
|
||||
importedConversations = selectedConversations;
|
||||
showImportSummary = true;
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
import * as Table from '$lib/components/ui/table';
|
||||
import { BadgeModality, CopyToClipboardIcon } from '$lib/components/app';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import { ChatService } from '$lib/services/chat';
|
||||
import { modelsStore, modelOptions, modelsLoading } from '$lib/stores/models.svelte';
|
||||
import { formatFileSize, formatParameters, formatNumber } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
@@ -16,38 +15,24 @@
|
||||
|
||||
let serverProps = $derived(serverStore.props);
|
||||
let modelName = $derived(modelsStore.singleModelName);
|
||||
let models = $derived(modelOptions());
|
||||
let isLoadingModels = $derived(modelsLoading());
|
||||
|
||||
// Get the first model for single-model mode display
|
||||
let firstModel = $derived(models[0] ?? null);
|
||||
|
||||
// Get modalities from modelStore using the model ID from the first model
|
||||
// For now it supports only for single-model mode, will be extended with further improvements for multi-model functioanlities
|
||||
let modalities = $derived.by(() => {
|
||||
if (!modelsData?.data?.[0]?.id) return [];
|
||||
|
||||
return modelsStore.getModelModalitiesArray(modelsData.data[0].id);
|
||||
if (!firstModel?.id) return [];
|
||||
return modelsStore.getModelModalitiesArray(firstModel.id);
|
||||
});
|
||||
|
||||
let modelsData = $state<ApiModelListResponse | null>(null);
|
||||
let isLoadingModels = $state(false);
|
||||
|
||||
// Fetch models data when dialog opens
|
||||
// Ensure models are fetched when dialog opens
|
||||
$effect(() => {
|
||||
if (open && !modelsData) {
|
||||
loadModelsData();
|
||||
if (open && models.length === 0) {
|
||||
modelsStore.fetch();
|
||||
}
|
||||
});
|
||||
|
||||
async function loadModelsData() {
|
||||
isLoadingModels = true;
|
||||
|
||||
try {
|
||||
modelsData = await ChatService.getModels();
|
||||
} catch (error) {
|
||||
console.error('Failed to load models data:', error);
|
||||
// Set empty data to prevent infinite loading
|
||||
modelsData = { object: 'list', data: [] };
|
||||
} finally {
|
||||
isLoadingModels = false;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open {onOpenChange}>
|
||||
@@ -70,8 +55,8 @@
|
||||
<div class="flex items-center justify-center py-8">
|
||||
<div class="text-sm text-muted-foreground">Loading model information...</div>
|
||||
</div>
|
||||
{:else if modelsData && modelsData.data.length > 0}
|
||||
{@const modelMeta = modelsData.data[0].meta}
|
||||
{:else if firstModel}
|
||||
{@const modelMeta = firstModel.meta}
|
||||
|
||||
{#if serverProps}
|
||||
<Table.Root>
|
||||
|
||||
@@ -126,8 +126,13 @@ export const TEXT_FILE_TYPES = {
|
||||
mimeTypes: [MimeTypeText.JAVA]
|
||||
},
|
||||
[FileTypeText.CPP]: {
|
||||
extensions: [FileExtensionText.CPP, FileExtensionText.C, FileExtensionText.H],
|
||||
mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.C_SRC, MimeTypeText.C_HDR]
|
||||
extensions: [
|
||||
FileExtensionText.CPP,
|
||||
FileExtensionText.C,
|
||||
FileExtensionText.H,
|
||||
FileExtensionText.HPP
|
||||
],
|
||||
mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.CPP_HDR, MimeTypeText.C_SRC, MimeTypeText.C_HDR]
|
||||
},
|
||||
[FileTypeText.PHP]: {
|
||||
extensions: [FileExtensionText.PHP],
|
||||
@@ -183,10 +188,30 @@ export const TEXT_FILE_TYPES = {
|
||||
},
|
||||
[FileTypeText.LATEX]: {
|
||||
extensions: [FileExtensionText.TEX],
|
||||
mimeTypes: [MimeTypeText.LATEX]
|
||||
mimeTypes: [MimeTypeText.LATEX, MimeTypeText.TEX, MimeTypeText.TEX_APP]
|
||||
},
|
||||
[FileTypeText.BIBTEX]: {
|
||||
extensions: [FileExtensionText.BIB],
|
||||
mimeTypes: [MimeTypeText.BIBTEX]
|
||||
},
|
||||
[FileTypeText.CUDA]: {
|
||||
extensions: [FileExtensionText.CU, FileExtensionText.CUH],
|
||||
mimeTypes: [MimeTypeText.CUDA]
|
||||
},
|
||||
[FileTypeText.VULKAN]: {
|
||||
extensions: [FileExtensionText.COMP],
|
||||
mimeTypes: [MimeTypeText.PLAIN]
|
||||
},
|
||||
[FileTypeText.HASKELL]: {
|
||||
extensions: [FileExtensionText.HS],
|
||||
mimeTypes: [MimeTypeText.HASKELL]
|
||||
},
|
||||
[FileTypeText.CSHARP]: {
|
||||
extensions: [FileExtensionText.CS],
|
||||
mimeTypes: [MimeTypeText.CSHARP]
|
||||
},
|
||||
[FileTypeText.PROPERTIES]: {
|
||||
extensions: [FileExtensionText.PROPERTIES],
|
||||
mimeTypes: [MimeTypeText.PROPERTIES]
|
||||
}
|
||||
} as const;
|
||||
|
||||
@@ -62,7 +62,12 @@ export enum FileTypeText {
|
||||
VUE = 'vue',
|
||||
SVELTE = 'svelte',
|
||||
LATEX = 'latex',
|
||||
BIBTEX = 'bibtex'
|
||||
BIBTEX = 'bibtex',
|
||||
CUDA = 'cuda',
|
||||
VULKAN = 'vulkan',
|
||||
HASKELL = 'haskell',
|
||||
CSHARP = 'csharp',
|
||||
PROPERTIES = 'properties'
|
||||
}
|
||||
|
||||
// File extension enums
|
||||
@@ -121,7 +126,14 @@ export enum FileExtensionText {
|
||||
VUE = '.vue',
|
||||
SVELTE = '.svelte',
|
||||
TEX = '.tex',
|
||||
BIB = '.bib'
|
||||
BIB = '.bib',
|
||||
CU = '.cu',
|
||||
CUH = '.cuh',
|
||||
COMP = '.comp',
|
||||
HPP = '.hpp',
|
||||
HS = '.hs',
|
||||
PROPERTIES = '.properties',
|
||||
CS = '.cs'
|
||||
}
|
||||
|
||||
// MIME type enums
|
||||
@@ -165,7 +177,10 @@ export enum MimeTypeText {
|
||||
CSV = 'text/csv',
|
||||
PYTHON = 'text/x-python',
|
||||
JAVA = 'text/x-java-source',
|
||||
CPP_HDR = 'text/x-c++hdr',
|
||||
CPP_SRC = 'text/x-c++src',
|
||||
CSHARP = 'text/x-csharp',
|
||||
HASKELL = 'text/x-haskell',
|
||||
C_SRC = 'text/x-csrc',
|
||||
C_HDR = 'text/x-chdr',
|
||||
PHP = 'text/x-php',
|
||||
@@ -182,6 +197,10 @@ export enum MimeTypeText {
|
||||
DART = 'text/x-dart',
|
||||
VUE = 'text/x-vue',
|
||||
SVELTE = 'text/x-svelte',
|
||||
LATEX = 'text/x-tex',
|
||||
BIBTEX = 'text/x-bibtex'
|
||||
TEX = 'text/x-tex',
|
||||
TEX_APP = 'application/x-tex',
|
||||
LATEX = 'application/x-latex',
|
||||
BIBTEX = 'text/x-bibtex',
|
||||
CUDA = 'text/x-cuda',
|
||||
PROPERTIES = 'text/properties'
|
||||
}
|
||||
|
||||
@@ -677,48 +677,6 @@ export class ChatService {
|
||||
// Utilities
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get server properties - static method for API compatibility (to be refactored)
|
||||
*/
|
||||
static async getServerProps(): Promise<ApiLlamaCppServerProps> {
|
||||
try {
|
||||
const response = await fetch(`./props`, {
|
||||
headers: getJsonHeaders()
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch server props: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error('Error fetching server props:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model information from /models endpoint (to be refactored)
|
||||
*/
|
||||
static async getModels(): Promise<ApiModelListResponse> {
|
||||
try {
|
||||
const response = await fetch(`./models`, {
|
||||
headers: getJsonHeaders()
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch models: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data;
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Injects a system message at the beginning of the conversation if provided.
|
||||
* Checks for existing system messages to avoid duplication.
|
||||
|
||||
@@ -7,7 +7,7 @@ import { getJsonHeaders } from '$lib/utils';
|
||||
*
|
||||
* This service handles communication with model-related endpoints:
|
||||
* - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode)
|
||||
* - `/models` - Router-specific model management (ROUTER mode only)
|
||||
* - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only)
|
||||
*
|
||||
* **Responsibilities:**
|
||||
* - List available models
|
||||
@@ -43,7 +43,7 @@ export class ModelsService {
|
||||
* Returns models with load status, paths, and other metadata
|
||||
*/
|
||||
static async listRouter(): Promise<ApiRouterModelsListResponse> {
|
||||
const response = await fetch(`${base}/models`, {
|
||||
const response = await fetch(`${base}/v1/models`, {
|
||||
headers: getJsonHeaders()
|
||||
});
|
||||
|
||||
|
||||
@@ -519,6 +519,19 @@ class ConversationsStore {
|
||||
return await DatabaseService.getConversationMessages(convId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Imports conversations from provided data (without file picker)
|
||||
* @param data - Array of conversation data with messages
|
||||
* @returns Import result with counts
|
||||
*/
|
||||
async importConversationsData(
|
||||
data: ExportedConversations
|
||||
): Promise<{ imported: number; skipped: number }> {
|
||||
const result = await DatabaseService.importConversations(data);
|
||||
await this.loadConversations();
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a message to the active messages array
|
||||
* Used by chatStore when creating new messages
|
||||
|
||||
@@ -370,6 +370,10 @@ class SettingsStore {
|
||||
return { ...this.config };
|
||||
}
|
||||
|
||||
canSyncParameter(key: string): boolean {
|
||||
return ParameterSyncService.canSyncParameter(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get parameter information including source for a specific parameter
|
||||
*/
|
||||
|
||||
@@ -77,6 +77,13 @@ export function getFileTypeCategory(mimeType: string): FileTypeCategory | null {
|
||||
case MimeTypeText.SVELTE:
|
||||
case MimeTypeText.LATEX:
|
||||
case MimeTypeText.BIBTEX:
|
||||
case MimeTypeText.CUDA:
|
||||
case MimeTypeText.CPP_HDR:
|
||||
case MimeTypeText.CSHARP:
|
||||
case MimeTypeText.HASKELL:
|
||||
case MimeTypeText.PROPERTIES:
|
||||
case MimeTypeText.TEX:
|
||||
case MimeTypeText.TEX_APP:
|
||||
return FileTypeCategory.TEXT;
|
||||
|
||||
default:
|
||||
@@ -144,6 +151,12 @@ export function getFileTypeCategoryByExtension(filename: string): FileTypeCatego
|
||||
case FileExtensionText.SVELTE:
|
||||
case FileExtensionText.TEX:
|
||||
case FileExtensionText.BIB:
|
||||
case FileExtensionText.COMP:
|
||||
case FileExtensionText.CU:
|
||||
case FileExtensionText.CUH:
|
||||
case FileExtensionText.HPP:
|
||||
case FileExtensionText.HS:
|
||||
case FileExtensionText.PROPERTIES:
|
||||
return FileTypeCategory.TEXT;
|
||||
|
||||
default:
|
||||
|
||||
3
vendor/cpp-httplib/CMakeLists.txt
vendored
3
vendor/cpp-httplib/CMakeLists.txt
vendored
@@ -144,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
endif()
|
||||
if (WIN32 AND NOT MSVC)
|
||||
target_link_libraries(${TARGET} PUBLIC crypt32)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user