mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-08 01:54:10 +00:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bde188d60f | ||
|
|
9d0229967a | ||
|
|
c4c10bfb86 | ||
|
|
817d743cc1 | ||
|
|
bd4ef13476 | ||
|
|
87a2084c45 | ||
|
|
3659aa28e9 | ||
|
|
2a73f81f8a | ||
|
|
7dba049b07 | ||
|
|
83c1171529 | ||
|
|
0d1324856f | ||
|
|
a67ef0f47f | ||
|
|
ef75a89fdb | ||
|
|
d8b5cdc4fe | ||
|
|
dea9ba27cb | ||
|
|
c6d1a00aa7 | ||
|
|
424c579455 | ||
|
|
e9f9483464 | ||
|
|
41c5e02f42 | ||
|
|
2e1c9cd814 | ||
|
|
190c4838bd | ||
|
|
e7c2cf1356 | ||
|
|
1257491047 | ||
|
|
083e18b11c | ||
|
|
3d94e967a1 | ||
|
|
7feb0a1005 | ||
|
|
0a8026e768 |
44
.github/workflows/build.yml
vendored
44
.github/workflows/build.yml
vendored
@@ -1602,33 +1602,33 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
# ggml-ci-x64-amd-vulkan:
|
||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# vulkaninfo --summary
|
||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
# ggml-ci-x64-amd-rocm:
|
||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# amd-smi static
|
||||
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
59
.github/workflows/release.yml
vendored
59
.github/workflows/release.yml
vendored
@@ -728,58 +728,6 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
|
||||
|
||||
openEuler-cann:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake libcurl-devel
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts (zip)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
|
||||
- name: Upload artifacts (tar)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
|
||||
|
||||
release:
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
|
||||
@@ -801,7 +749,6 @@ jobs:
|
||||
- macOS-arm64
|
||||
- macOS-x64
|
||||
- ios-xcode-build
|
||||
- openEuler-cann
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -893,12 +840,6 @@ jobs:
|
||||
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
|
||||
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)
|
||||
|
||||
**openEuler:**
|
||||
- [openEuler x86 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-x86.tar.gz)
|
||||
- [openEuler x86 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-x86.tar.gz)
|
||||
- [openEuler aarch64 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-aarch64.tar.gz)
|
||||
- [openEuler aarch64 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-aarch64.tar.gz)
|
||||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
uses: actions/github-script@v3
|
||||
|
||||
@@ -72,6 +72,12 @@ if (MSVC)
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
||||
endif()
|
||||
|
||||
if (LLAMA_STANDALONE)
|
||||
# enable parallel builds for msbuild
|
||||
list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true)
|
||||
list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true)
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
||||
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
|
||||
else()
|
||||
@@ -193,11 +199,6 @@ if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
|
||||
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
# Target Windows 8 for PrefetchVirtualMemory
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# build the library
|
||||
#
|
||||
|
||||
@@ -10,13 +10,16 @@
|
||||
/common/arg.* @ggerganov
|
||||
/common/base64.hpp.* @ggerganov
|
||||
/common/build-info.* @ggerganov
|
||||
/common/chat-peg-parser.* @aldehir
|
||||
/common/common.* @ggerganov
|
||||
/common/console.* @ggerganov
|
||||
/common/http.* @angt
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
/common/unicode.* @aldehir
|
||||
/convert_*.py @CISC
|
||||
/examples/batched.swift/ @ggerganov
|
||||
/examples/batched/ @ggerganov
|
||||
|
||||
@@ -39,26 +39,10 @@ if(Git_FOUND)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
||||
if (CMAKE_VS_PLATFORM_NAME)
|
||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||
else()
|
||||
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
endif()
|
||||
else()
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} --version
|
||||
OUTPUT_VARIABLE OUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
string(REGEX REPLACE " *\n.*" "" OUT "${OUT}")
|
||||
set(BUILD_COMPILER ${OUT})
|
||||
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} -dumpmachine
|
||||
OUTPUT_VARIABLE OUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
set(BUILD_TARGET ${OUT})
|
||||
if(CMAKE_VS_PLATFORM_NAME)
|
||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||
else()
|
||||
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
endif()
|
||||
|
||||
@@ -52,6 +52,8 @@ add_library(${TARGET} STATIC
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat-peg-parser.cpp
|
||||
chat-peg-parser.h
|
||||
chat.cpp
|
||||
chat.h
|
||||
common.cpp
|
||||
@@ -69,12 +71,16 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
speculative.h
|
||||
unicode.cpp
|
||||
unicode.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
|
||||
@@ -427,7 +427,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) {
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#include "chat-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "peg-parser.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -1483,6 +1485,11 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
return common_chat_peg_parse(syntax.parser, input, is_partial, syntax);
|
||||
}
|
||||
common_chat_msg_parser builder(input, is_partial, syntax);
|
||||
try {
|
||||
common_chat_parse(builder);
|
||||
@@ -1500,3 +1507,36 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Runs the PEG grammar held in `parser` over `input` and converts the
// resulting AST into an assistant message.
//
// The mapper used for the conversion is picked from `syntax.format`:
// PEG_NATIVE and PEG_CONSTRUCTED get their dedicated mappers, every other
// format falls back to the generic mapper.
//
// Throws std::runtime_error when `parser` is empty or when the parse fails.
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
    if (parser.empty()) {
        throw std::runtime_error("Failed to parse due to missing parser definition.");
    }

    LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str());

    common_peg_parse_context ctx(input, is_partial);
    auto parsed = parser.parse(ctx);
    if (parsed.fail()) {
        throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(parsed.end));
    }

    common_chat_msg msg;
    msg.role = "assistant";

    switch (syntax.format) {
        case COMMON_CHAT_FORMAT_PEG_NATIVE: {
            common_chat_peg_native_mapper mapper(msg);
            mapper.from_ast(ctx.ast, parsed);
            break;
        }
        case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: {
            common_chat_peg_constructed_mapper mapper(msg);
            mapper.from_ast(ctx.ast, parsed);
            break;
        }
        default: {
            // Generic mapper
            common_chat_peg_mapper mapper(msg);
            mapper.from_ast(ctx.ast, parsed);
            break;
        }
    }

    // The JSON dump is only logged once the message is complete.
    if (!is_partial) {
        LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
    }
    return msg;
}
|
||||
|
||||
114
common/chat-peg-parser.cpp
Normal file
114
common/chat-peg-parser.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
#include "chat-peg-parser.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
// Returns the prefix of `sv` that remains after dropping every trailing
// character for which std::isspace() is true. The returned view aliases the
// same storage as `sv`; nothing is copied.
static std::string_view trim_trailing_space(std::string_view sv) {
    size_t keep = sv.size();
    // Cast to unsigned char first: std::isspace is undefined for negative values.
    while (keep > 0 && std::isspace(static_cast<unsigned char>(sv[keep - 1]))) {
        --keep;
    }
    return sv.substr(0, keep);
}
|
||||
|
||||
// Walks the AST reachable from `result` and feeds every visited node to map().
// map() is virtual, so derived mappers receive the nodes through their override.
// Note: the `result` parameter shadows the `result` message member in this scope.
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
    auto forward_to_map = [this](const common_peg_ast_node & node) {
        map(node);
    };
    arena.visit(result, forward_to_map);
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
||||
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
||||
|
||||
if (is_reasoning) {
|
||||
result.reasoning_content = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_content) {
|
||||
result.content = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
|
||||
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
|
||||
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
}
|
||||
|
||||
if (is_tool_id && current_tool) {
|
||||
current_tool->id = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_name && current_tool) {
|
||||
current_tool->name = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_args && current_tool) {
|
||||
current_tool->arguments = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
|
||||
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
|
||||
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
|
||||
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
|
||||
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
|
||||
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
|
||||
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
arg_count = 0;
|
||||
}
|
||||
|
||||
if (is_tool_name) {
|
||||
current_tool->name = std::string(node.text);
|
||||
current_tool->arguments = "{";
|
||||
}
|
||||
|
||||
if (is_arg_open) {
|
||||
needs_closing_quote = false;
|
||||
}
|
||||
|
||||
if (is_arg_name && current_tool) {
|
||||
if (arg_count > 0) {
|
||||
current_tool->arguments += ",";
|
||||
}
|
||||
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
|
||||
++arg_count;
|
||||
}
|
||||
|
||||
if (is_arg_string && current_tool) {
|
||||
// Serialize to JSON, but exclude the end quote
|
||||
std::string dumped = json(node.text).dump();
|
||||
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
|
||||
needs_closing_quote = true;
|
||||
}
|
||||
|
||||
if (is_arg_close && current_tool) {
|
||||
if (needs_closing_quote) {
|
||||
current_tool->arguments += "\"";
|
||||
}
|
||||
}
|
||||
|
||||
if (is_arg_json && current_tool) {
|
||||
current_tool->arguments += std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_close && current_tool) {
|
||||
current_tool->arguments += "}";
|
||||
}
|
||||
}
|
||||
105
common/chat-peg-parser.h
Normal file
105
common/chat-peg-parser.h
Normal file
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
public:
|
||||
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
||||
static constexpr const char * REASONING = "reasoning";
|
||||
static constexpr const char * CONTENT = "content";
|
||||
|
||||
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
||||
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
||||
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
||||
};
|
||||
|
||||
// Builds a parser arena from a grammar description: `fn` receives a fresh
// common_chat_peg_builder and returns the root parser of the grammar.
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
    common_chat_peg_builder builder;
    const common_peg_parser root = fn(builder);
    builder.set_root(root);
    return builder.build();
}
|
||||
|
||||
class common_chat_peg_mapper {
|
||||
public:
|
||||
common_chat_msg & result;
|
||||
|
||||
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
|
||||
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
virtual void map(const common_peg_ast_node & node);
|
||||
};
|
||||
|
||||
class common_chat_peg_native_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_ID = "tool-id";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARGS = "tool-args";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
|
||||
public:
|
||||
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
// Builds a parser arena for the "native" tool-call grammar: `fn` receives a
// fresh common_chat_peg_native_builder and returns the root parser.
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
    common_chat_peg_native_builder builder;
    const common_peg_parser root = fn(builder);
    builder.set_root(root);
    return builder.build();
}
|
||||
|
||||
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARG = "tool-arg";
|
||||
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
||||
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
||||
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
||||
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
|
||||
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
|
||||
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
|
||||
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
|
||||
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
|
||||
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
int arg_count = 0;
|
||||
bool needs_closing_quote = false;
|
||||
|
||||
public:
|
||||
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
// Builds a parser arena for the "constructed" tool-call grammar: `fn`
// receives a fresh common_chat_peg_constructed_builder and returns the root
// parser.
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
    common_chat_peg_constructed_builder builder;
    const common_peg_parser root = fn(builder);
    builder.set_root(root);
    return builder.build();
}
|
||||
@@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const
|
||||
return message;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
if (previous_msg.reasoning_content != new_msg.reasoning_content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
|
||||
}
|
||||
if (previous_msg.content != new_msg.content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
|
||||
if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) {
|
||||
diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3);
|
||||
} else {
|
||||
diffs.reserve(3);
|
||||
}
|
||||
|
||||
if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
|
||||
// TODO: these can become expensive for long messages - how to optimize?
|
||||
if (msg_prv.reasoning_content != msg_new.reasoning_content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content);
|
||||
}
|
||||
if (msg_prv.content != msg_new.content) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.content_delta = string_diff(msg_prv.content, msg_new.content);
|
||||
}
|
||||
|
||||
if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) {
|
||||
throw std::runtime_error("Invalid diff: now finding less tool calls!");
|
||||
}
|
||||
|
||||
if (!previous_msg.tool_calls.empty()) {
|
||||
auto idx = previous_msg.tool_calls.size() - 1;
|
||||
const auto & pref = previous_msg.tool_calls[idx];
|
||||
const auto & newf = new_msg.tool_calls[idx];
|
||||
if (!msg_prv.tool_calls.empty()) {
|
||||
const auto idx = msg_prv.tool_calls.size() - 1;
|
||||
const auto & pref = msg_prv.tool_calls[idx];
|
||||
const auto & newf = msg_new.tool_calls[idx];
|
||||
if (pref.name != newf.name) {
|
||||
throw std::runtime_error("Invalid diff: tool call mismatch!");
|
||||
}
|
||||
auto args_diff = string_diff(pref.arguments, newf.arguments);
|
||||
const auto args_diff = string_diff(pref.arguments, newf.arguments);
|
||||
if (!args_diff.empty() || pref.id != newf.id) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.tool_call_index = idx;
|
||||
@@ -118,11 +125,12 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
|
||||
diff.tool_call_delta.arguments = args_diff;
|
||||
}
|
||||
}
|
||||
for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
|
||||
for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) {
|
||||
auto & diff = diffs.emplace_back();
|
||||
diff.tool_call_index = idx;
|
||||
diff.tool_call_delta = new_msg.tool_calls[idx];
|
||||
diff.tool_call_delta = msg_new.tool_calls[idx];
|
||||
}
|
||||
|
||||
return diffs;
|
||||
}
|
||||
|
||||
@@ -649,6 +657,9 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
|
||||
case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
|
||||
case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
@@ -76,7 +77,7 @@ struct common_chat_msg_diff {
|
||||
size_t tool_call_index = std::string::npos;
|
||||
common_chat_tool_call tool_call_delta;
|
||||
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
|
||||
|
||||
bool operator==(const common_chat_msg_diff & other) const {
|
||||
return content_delta == other.content_delta
|
||||
@@ -124,6 +125,11 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
|
||||
// These are intended to be parsed by the PEG parser
|
||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -154,6 +160,7 @@ struct common_chat_params {
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
@@ -163,6 +170,7 @@ struct common_chat_syntax {
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
@@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
|
||||
@@ -786,11 +786,29 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
#include <iostream>
|
||||
|
||||
|
||||
#ifdef _WIN32
|
||||
// Converts a UTF-8 encoded string to a wide string using MultiByteToWideChar.
// Returns an empty wide string when `str` is empty or when the size query
// reports no output characters.
static std::wstring utf8_to_wstring(const std::string & str) {
    std::wstring wide;
    if (str.empty()) {
        return wide;
    }

    const int n_bytes = (int)str.size();

    // First call: ask how many wide characters the conversion will produce.
    const int n_wide = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), n_bytes, NULL, 0);
    if (n_wide > 0) {
        // Second call: convert into the zero-filled, pre-sized buffer.
        wide.resize(n_wide);
        MultiByteToWideChar(CP_UTF8, 0, str.c_str(), n_bytes, &wide[0], n_wide);
    }

    return wide;
}
|
||||
#endif
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
std::wstring wpath = utf8_to_wstring(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#define DIRECTORY_SEPARATOR '\\'
|
||||
#else
|
||||
|
||||
1712
common/peg-parser.cpp
Normal file
1712
common/peg-parser.cpp
Normal file
File diff suppressed because it is too large
Load Diff
459
common/peg-parser.h
Normal file
459
common/peg-parser.h
Normal file
@@ -0,0 +1,459 @@
|
||||
#pragma once
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <variant>
|
||||
|
||||
struct common_grammar_builder;
|
||||
|
||||
class common_peg_parser_builder;
|
||||
|
||||
using common_peg_parser_id = size_t;
|
||||
constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast<common_peg_parser_id>(-1);
|
||||
|
||||
using common_peg_ast_id = size_t;
|
||||
constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast<common_peg_ast_id>(-1);
|
||||
|
||||
// Lightweight wrapper around common_peg_parser_id for convenience
|
||||
class common_peg_parser {
|
||||
common_peg_parser_id id_;
|
||||
common_peg_parser_builder & builder_;
|
||||
|
||||
public:
|
||||
common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {}
|
||||
common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {}
|
||||
|
||||
common_peg_parser & operator=(const common_peg_parser & other);
|
||||
common_peg_parser & operator+=(const common_peg_parser & other);
|
||||
common_peg_parser & operator|=(const common_peg_parser & other);
|
||||
|
||||
operator common_peg_parser_id() const { return id_; }
|
||||
common_peg_parser_id id() const { return id_; }
|
||||
|
||||
common_peg_parser_builder & builder() const { return builder_; }
|
||||
|
||||
// Creates a sequence
|
||||
common_peg_parser operator+(const common_peg_parser & other) const;
|
||||
|
||||
// Creates a sequence separated by spaces.
|
||||
common_peg_parser operator<<(const common_peg_parser & other) const;
|
||||
|
||||
// Creates a choice
|
||||
common_peg_parser operator|(const common_peg_parser & other) const;
|
||||
|
||||
common_peg_parser operator+(const char * str) const;
|
||||
common_peg_parser operator+(const std::string & str) const;
|
||||
common_peg_parser operator<<(const char * str) const;
|
||||
common_peg_parser operator<<(const std::string & str) const;
|
||||
common_peg_parser operator|(const char * str) const;
|
||||
common_peg_parser operator|(const std::string & str) const;
|
||||
};
|
||||
|
||||
common_peg_parser operator+(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator+(const std::string & str, const common_peg_parser & p);
|
||||
common_peg_parser operator<<(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator<<(const std::string & str, const common_peg_parser & p);
|
||||
common_peg_parser operator|(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator|(const std::string & str, const common_peg_parser & p);
|
||||
|
||||
// Outcome of a parse attempt. The numeric values are spelled out explicitly;
// keep them stable in case they are relied upon elsewhere.
enum common_peg_parse_result_type {
    COMMON_PEG_PARSE_RESULT_FAIL            = 0,
    COMMON_PEG_PARSE_RESULT_SUCCESS         = 1,
    // NOTE(review): presumably produced when a partial input ends mid-match -- confirm in peg-parser.cpp
    COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2,
};
|
||||
|
||||
const char * common_peg_parse_result_type_name(common_peg_parse_result_type type);
|
||||
|
||||
struct common_peg_ast_node {
|
||||
common_peg_ast_id id;
|
||||
std::string rule;
|
||||
std::string tag;
|
||||
size_t start;
|
||||
size_t end;
|
||||
std::string_view text;
|
||||
std::vector<common_peg_ast_id> children;
|
||||
|
||||
bool is_partial = false;
|
||||
};
|
||||
|
||||
struct common_peg_parse_result;
|
||||
|
||||
using common_peg_ast_visitor = std::function<void(const common_peg_ast_node & node)>;
|
||||
|
||||
class common_peg_ast_arena {
|
||||
std::vector<common_peg_ast_node> nodes_;
|
||||
public:
|
||||
common_peg_ast_id add_node(
|
||||
const std::string & rule,
|
||||
const std::string & tag,
|
||||
size_t start,
|
||||
size_t end,
|
||||
std::string_view text,
|
||||
std::vector<common_peg_ast_id> children,
|
||||
bool is_partial = false
|
||||
) {
|
||||
common_peg_ast_id id = nodes_.size();
|
||||
nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial});
|
||||
return id;
|
||||
}
|
||||
|
||||
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
|
||||
|
||||
size_t size() const { return nodes_.size(); }
|
||||
|
||||
void clear() { nodes_.clear(); }
|
||||
|
||||
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
|
||||
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
|
||||
};
|
||||
|
||||
struct common_peg_parse_result {
|
||||
common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL;
|
||||
size_t start = 0;
|
||||
size_t end = 0;
|
||||
|
||||
std::vector<common_peg_ast_id> nodes;
|
||||
|
||||
common_peg_parse_result() = default;
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start)
|
||||
: type(type), start(start), end(start) {}
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end)
|
||||
: type(type), start(start), end(end) {}
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector<common_peg_ast_id> nodes)
|
||||
: type(type), start(start), end(end), nodes(std::move(nodes)) {}
|
||||
|
||||
bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; }
|
||||
bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; }
|
||||
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
||||
};
|
||||
|
||||
struct common_peg_parse_context {
|
||||
std::string input;
|
||||
bool is_partial;
|
||||
common_peg_ast_arena ast;
|
||||
|
||||
int parse_depth;
|
||||
|
||||
common_peg_parse_context()
|
||||
: is_partial(false), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input)
|
||||
: input(input), is_partial(false), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input, bool is_partial)
|
||||
: input(input), is_partial(is_partial), parse_depth(0) {}
|
||||
};
|
||||
|
||||
class common_peg_arena;
|
||||
|
||||
// Parser variants
|
||||
struct common_peg_epsilon_parser {};
|
||||
|
||||
struct common_peg_start_parser {};
|
||||
|
||||
struct common_peg_end_parser {};
|
||||
|
||||
struct common_peg_literal_parser {
|
||||
std::string literal;
|
||||
};
|
||||
|
||||
struct common_peg_sequence_parser {
|
||||
std::vector<common_peg_parser_id> children;
|
||||
};
|
||||
|
||||
struct common_peg_choice_parser {
|
||||
std::vector<common_peg_parser_id> children;
|
||||
};
|
||||
|
||||
struct common_peg_repetition_parser {
|
||||
common_peg_parser_id child;
|
||||
int min_count;
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_and_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_not_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_any_parser {};
|
||||
|
||||
struct common_peg_space_parser {};
|
||||
|
||||
struct common_peg_chars_parser {
|
||||
struct char_range {
|
||||
uint32_t start;
|
||||
uint32_t end;
|
||||
bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; }
|
||||
};
|
||||
|
||||
std::string pattern;
|
||||
std::vector<char_range> ranges;
|
||||
bool negated;
|
||||
int min_count;
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_json_string_parser {};
|
||||
|
||||
struct common_peg_until_parser {
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
struct common_peg_schema_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string name;
|
||||
std::shared_ptr<nlohmann::ordered_json> schema;
|
||||
|
||||
// Indicates if the GBNF should accept a raw string that matches the schema.
|
||||
bool raw;
|
||||
};
|
||||
|
||||
struct common_peg_rule_parser {
|
||||
std::string name;
|
||||
common_peg_parser_id child;
|
||||
bool trigger;
|
||||
};
|
||||
|
||||
struct common_peg_ref_parser {
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct common_peg_atomic_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_tag_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
common_peg_start_parser,
|
||||
common_peg_end_parser,
|
||||
common_peg_literal_parser,
|
||||
common_peg_sequence_parser,
|
||||
common_peg_choice_parser,
|
||||
common_peg_repetition_parser,
|
||||
common_peg_and_parser,
|
||||
common_peg_not_parser,
|
||||
common_peg_any_parser,
|
||||
common_peg_space_parser,
|
||||
common_peg_chars_parser,
|
||||
common_peg_json_string_parser,
|
||||
common_peg_until_parser,
|
||||
common_peg_schema_parser,
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
std::vector<common_peg_parser_variant> parsers_;
|
||||
std::unordered_map<std::string, common_peg_parser_id> rules_;
|
||||
common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID;
|
||||
|
||||
public:
|
||||
const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); }
|
||||
common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); }
|
||||
|
||||
size_t size() const { return parsers_.size(); }
|
||||
bool empty() const { return parsers_.empty(); }
|
||||
|
||||
common_peg_parser_id get_rule(const std::string & name) const;
|
||||
bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); }
|
||||
|
||||
common_peg_parser_id root() const { return root_; }
|
||||
void set_root(common_peg_parser_id id) { root_ = id; }
|
||||
|
||||
common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const;
|
||||
common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const;
|
||||
|
||||
void resolve_refs();
|
||||
|
||||
void build_grammar(const common_grammar_builder & builder, bool lazy = false) const;
|
||||
|
||||
std::string dump(common_peg_parser_id id) const;
|
||||
|
||||
nlohmann::json to_json() const;
|
||||
static common_peg_arena from_json(const nlohmann::json & j);
|
||||
|
||||
std::string save() const;
|
||||
void load(const std::string & data);
|
||||
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
|
||||
common_peg_parser_id resolve_ref(common_peg_parser_id id);
|
||||
};
|
||||
|
||||
class common_peg_parser_builder {
|
||||
common_peg_arena arena_;
|
||||
|
||||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||
|
||||
public:
|
||||
common_peg_parser_builder();
|
||||
|
||||
// Match nothing, always succeed.
|
||||
// S -> ε
|
||||
common_peg_parser eps() { return add(common_peg_epsilon_parser{}); }
|
||||
|
||||
// Matches the start of the input.
|
||||
// S -> ^
|
||||
common_peg_parser start() { return add(common_peg_start_parser{}); }
|
||||
|
||||
// Matches the end of the input.
|
||||
// S -> $
|
||||
common_peg_parser end() { return add(common_peg_end_parser{}); }
|
||||
|
||||
// Matches an exact literal string.
|
||||
// S -> "hello"
|
||||
common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); }
|
||||
|
||||
// Matches a sequence of parsers in order, all must succeed.
|
||||
// S -> A B C
|
||||
common_peg_parser sequence() { return add(common_peg_sequence_parser{}); }
|
||||
common_peg_parser sequence(const std::vector<common_peg_parser_id> & parsers);
|
||||
common_peg_parser sequence(const std::vector<common_peg_parser> & parsers);
|
||||
common_peg_parser sequence(std::initializer_list<common_peg_parser> parsers);
|
||||
|
||||
// Matches the first parser that succeeds from a list of alternatives.
|
||||
// S -> A | B | C
|
||||
common_peg_parser choice() { return add(common_peg_choice_parser{}); }
|
||||
common_peg_parser choice(const std::vector<common_peg_parser_id> & parsers);
|
||||
common_peg_parser choice(const std::vector<common_peg_parser> & parsers);
|
||||
common_peg_parser choice(std::initializer_list<common_peg_parser> parsers);
|
||||
|
||||
// Matches one or more repetitions of a parser.
|
||||
// S -> A+
|
||||
common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); }
|
||||
|
||||
// Matches zero or more repetitions of a parser, always succeeds.
|
||||
// S -> A*
|
||||
common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); }
|
||||
|
||||
// Matches zero or one occurrence of a parser, always succeeds.
|
||||
// S -> A?
|
||||
common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); }
|
||||
|
||||
// Positive lookahead: succeeds if child parser succeeds, consumes no input.
|
||||
// S -> &A
|
||||
common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); }
|
||||
|
||||
// Negative lookahead: succeeds if child parser fails, consumes no input.
|
||||
// S -> !A
|
||||
common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); }
|
||||
|
||||
// Matches any single character.
|
||||
// S -> .
|
||||
common_peg_parser any() { return add(common_peg_any_parser{}); }
|
||||
|
||||
// Matches between min and max repetitions of characters from a character class.
|
||||
// S -> [a-z]{m,n}
|
||||
//
|
||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||
common_peg_parser chars(const std::string & classes, int min = 1, int max = -1);
|
||||
|
||||
// Creates a lightweight reference to a named rule (resolved during build()).
|
||||
// Use this for forward references in recursive grammars.
|
||||
// expr_ref -> expr
|
||||
common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); }
|
||||
|
||||
// Matches zero or more whitespace characters (space, tab, newline).
|
||||
// S -> [ \t\n]*
|
||||
common_peg_parser space() { return add(common_peg_space_parser{}); }
|
||||
|
||||
// Matches all characters until a delimiter is found (delimiter not consumed).
|
||||
// S -> (!delim .)*
|
||||
common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); }
|
||||
|
||||
// Matches all characters until one of the delimiters in the list is found (delimiter not consumed).
|
||||
// S -> (!delim .)*
|
||||
common_peg_parser until_one_of(const std::vector<std::string> & delimiters) { return add(common_peg_until_parser{delimiters}); }
|
||||
|
||||
// Matches everything
|
||||
// S -> .*
|
||||
common_peg_parser rest() { return until_one_of({}); }
|
||||
|
||||
// Matches between min and max repetitions of a parser (inclusive).
|
||||
// S -> A{m,n}
|
||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||
common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); }
|
||||
|
||||
// Matches exactly n repetitions of a parser.
|
||||
// S -> A{n}
|
||||
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
|
||||
|
||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||
// value -> object | array | string | number | true | false | null
|
||||
common_peg_parser json();
|
||||
common_peg_parser json_object();
|
||||
common_peg_parser json_string();
|
||||
common_peg_parser json_array();
|
||||
common_peg_parser json_number();
|
||||
common_peg_parser json_bool();
|
||||
common_peg_parser json_null();
|
||||
|
||||
// Matches JSON string content without the surrounding quotes.
|
||||
// Useful for extracting content within a JSON string.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches a JSON object member with a key and associated parser as the
|
||||
// value.
|
||||
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
|
||||
|
||||
// Wraps a parser with JSON schema metadata for grammar generation.
|
||||
// Used internally to convert JSON schemas to GBNF grammar rules.
|
||||
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
|
||||
|
||||
// Creates a named rule, stores it in the grammar, and returns a ref.
|
||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||
// auto json = p.rule("json", json_obj | json_arr | ...)
|
||||
common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false);
|
||||
|
||||
// Creates a named rule using a builder function, and returns a ref.
|
||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||
// auto json = p.rule("json", [&]() { return json_object() | json_array() | ... })
|
||||
common_peg_parser rule(const std::string & name, const std::function<common_peg_parser()> & builder, bool trigger = false);
|
||||
|
||||
// Creates a trigger rule. When generating a lazy grammar from the parser,
|
||||
// only trigger rules and descendents are emitted.
|
||||
common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); }
|
||||
common_peg_parser trigger_rule(const std::string & name, const std::function<common_peg_parser()> & builder) { return rule(name, builder, true); }
|
||||
|
||||
// Creates an atomic parser. Atomic parsers do not create an AST node if
|
||||
// the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is
|
||||
// intended for situations where partial output is undesirable.
|
||||
common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); }
|
||||
|
||||
// Tags create nodes in the generated AST for semantic purposes.
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
};
|
||||
|
||||
// Helper function for building parsers
|
||||
common_peg_arena build_peg_parser(const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);
|
||||
64
common/unicode.cpp
Normal file
64
common/unicode.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
#include "unicode.h"
|
||||
|
||||
// implementation adopted from src/unicode.cpp
|
||||
|
||||
// Length in bytes of the UTF-8 sequence introduced by `first_byte`, decided by
// its high nibble: 0x0-0xB -> 1, 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4.
// Bytes that cannot start a sequence are not rejected here: continuation bytes
// (0x80-0xBF) report 1 and 0xF8-0xFF report 4, so the result is never 0.
size_t utf8_sequence_length(unsigned char first_byte) {
    if (first_byte < 0xc0) {
        return 1;
    }
    if (first_byte < 0xe0) {
        return 2;
    }
    if (first_byte < 0xf0) {
        return 3;
    }
    return 4;
}
|
||||
|
||||
// Decodes the UTF-8 codepoint that starts at `input[offset]`.
//
// Returns:
//   SUCCESS    - `codepoint` holds the decoded scalar value and `bytes_consumed`
//                is the length of its encoding (1-4).
//   INCOMPLETE - `offset` is at or past the end of the input, or the input ends
//                inside a sequence whose available bytes are all valid so far.
//   INVALID    - the bytes cannot form a well-formed sequence: a stray
//                continuation byte, an impossible lead byte (0xC0, 0xC1,
//                0xF5-0xFF), a bad continuation byte, an overlong encoding, a
//                UTF-16 surrogate (U+D800-U+DFFF) or a value above U+10FFFF.
//
// Bytes that are present are validated before INCOMPLETE is reported, so
// truncated input that is already malformed is rejected immediately rather
// than asking for more data.
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
    if (offset >= input.size()) {
        return utf8_parse_result(utf8_parse_result::INCOMPLETE);
    }

    // Work on unsigned bytes so the result does not depend on the signedness of char.
    const auto byte_at = [&](size_t i) { return static_cast<unsigned char>(input[offset + i]); };

    const unsigned char lead = byte_at(0);

    // ASCII fast path
    if (lead < 0x80) {
        return utf8_parse_result(utf8_parse_result::SUCCESS, lead, 1);
    }

    size_t   length = 0;  // total bytes in the sequence
    uint32_t value  = 0;  // payload bits accumulated so far
    uint32_t lowest = 0;  // smallest codepoint that needs `length` bytes (overlong check)

    if ((lead & 0xe0) == 0xc0) {
        length = 2;
        value  = lead & 0x1f;
        lowest = 0x80;
    } else if ((lead & 0xf0) == 0xe0) {
        length = 3;
        value  = lead & 0x0f;
        lowest = 0x800;
    } else if ((lead & 0xf8) == 0xf0) {
        length = 4;
        value  = lead & 0x07;
        lowest = 0x10000;
    } else {
        // Continuation byte used as a first byte, or 0xF8-0xFF
        return utf8_parse_result(utf8_parse_result::INVALID);
    }

    // 0xC0/0xC1 can only start overlong forms and 0xF5-0xF7 can only encode
    // values above U+10FFFF, so no amount of additional input makes them valid.
    if (lead == 0xc0 || lead == 0xc1 || lead > 0xf4) {
        return utf8_parse_result(utf8_parse_result::INVALID);
    }

    const size_t available = input.size() - offset;
    for (size_t i = 1; i < length; ++i) {
        if (i >= available) {
            return utf8_parse_result(utf8_parse_result::INCOMPLETE);
        }
        const unsigned char next = byte_at(i);
        if ((next & 0xc0) != 0x80) {
            return utf8_parse_result(utf8_parse_result::INVALID);
        }
        value = (value << 6) | (next & 0x3f);
    }

    const bool overlong     = value < lowest;
    const bool surrogate    = value >= 0xd800 && value <= 0xdfff;
    const bool out_of_range = value > 0x10ffff;
    if (overlong || surrogate || out_of_range) {
        return utf8_parse_result(utf8_parse_result::INVALID);
    }

    return utf8_parse_result(utf8_parse_result::SUCCESS, value, length);
}
|
||||
22
common/unicode.h
Normal file
22
common/unicode.h
Normal file
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <string_view>
|
||||
|
||||
// UTF-8 parsing utilities for streaming-aware unicode support
|
||||
|
||||
// Result of decoding one UTF-8 codepoint.
struct utf8_parse_result {
    // Decoded codepoint (only valid if status == SUCCESS)
    uint32_t codepoint;
    // How many bytes this codepoint uses (1-4); the constructor defaults it to 0
    size_t bytes_consumed;
    // Outcome of the decode attempt
    enum status { SUCCESS, INCOMPLETE, INVALID } status;

    // Only the status is mandatory; codepoint and byte count default to 0.
    utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0)
        : codepoint{cp}, bytes_consumed{bytes}, status{s} {}
};
|
||||
|
||||
// Determine the expected length of a UTF-8 sequence from its first byte
|
||||
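// Never returns 0: the length is taken from the high nibble only, so a byte that
// cannot start a sequence (e.g. a continuation byte 0x80-0xBF) is reported as 1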
|
||||
size_t utf8_sequence_length(unsigned char first_byte);
|
||||
|
||||
// Parse a single UTF-8 codepoint from input
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
@@ -2341,9 +2341,18 @@ class LlamaModel(TextModel):
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(False)
|
||||
|
||||
template_dir = Path(__file__).parent / "models/templates/"
|
||||
local_template_file_path = self.dir_model / "chat_template.jinja"
|
||||
|
||||
if self.is_mistral_format and local_template_file_path.is_file():
|
||||
# Ministral-3 and other new Mistral models come with chat templates.
|
||||
# ref: https://huggingface.co/mistralai/Ministral-3-14B-Instruct-2512/tree/main
|
||||
logger.info("Using an existing Mistral local chat template.")
|
||||
|
||||
with open(local_template_file_path, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
elif not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
||||
template_dir = Path(__file__).parent / "models/templates/"
|
||||
|
||||
if not self.is_mistral_format or not self.disable_mistral_community_chat_template:
|
||||
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
|
||||
if self.is_mistral_format:
|
||||
logger.info(
|
||||
@@ -2351,9 +2360,12 @@ class LlamaModel(TextModel):
|
||||
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
|
||||
)
|
||||
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
else:
|
||||
logger.info("Not using a Mistral community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
||||
logger.info("Not using a Mistral local or community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
|
||||
template = None
|
||||
|
||||
if template is not None:
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
|
||||
def set_vocab(self):
|
||||
if self.is_mistral_format:
|
||||
|
||||
288
docs/development/parsing.md
Normal file
288
docs/development/parsing.md
Normal file
@@ -0,0 +1,288 @@
|
||||
# Parsing Model Output
|
||||
|
||||
The `common` library contains a PEG parser implementation suitable for parsing
|
||||
model output.
|
||||
|
||||
Types with the prefix `common_peg_*` are intended for general use and may have
|
||||
applications beyond parsing model output, such as parsing user-provided regex
|
||||
patterns.
|
||||
|
||||
Types with the prefix `common_chat_peg_*` are specialized helpers for model
|
||||
output.
|
||||
|
||||
The parser features:
|
||||
|
||||
- Partial parsing of streaming input
|
||||
- Built-in JSON parsers
|
||||
- AST generation with semantics via "tagged" nodes
|
||||
|
||||
## Example
|
||||
|
||||
Below is a contrived example demonstrating how to use the PEG parser to parse
|
||||
output from a model that emits arguments as JSON.
|
||||
|
||||
```cpp
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\"");
|
||||
auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}");
|
||||
}
|
||||
|
||||
// Define the tool call structure: <tool_call>[{tool}]</tool_call>
|
||||
auto tool_call = p.trigger_rule("tool-call",
|
||||
p.sequence({
|
||||
p.literal("<tool_call>["),
|
||||
tool_choice,
|
||||
p.literal("]</tool_call>")
|
||||
})
|
||||
);
|
||||
|
||||
// Parser accepts content, optionally followed by a tool call
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.optional(tool_call),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
For a more complete example, see `test_example_native()` in
|
||||
[tests/test-chat-peg-parser.cpp](../../tests/test-chat-peg-parser.cpp).
|
||||
|
||||
## Parsers/Combinators
|
||||
|
||||
### Basic Matchers
|
||||
|
||||
- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match)
|
||||
- **`start()`** - Matches the start of input (anchor `^`)
|
||||
- **`end()`** - Matches the end of input (anchor `$`)
|
||||
- **`literal(string)`** - Matches an exact literal string
|
||||
- **`any()`** - Matches any single character (`.`)
|
||||
|
||||
### Combinators
|
||||
|
||||
- **`sequence(...)`** - Matches parsers in order; all must succeed
|
||||
- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice)
|
||||
- **`one_or_more(p)`** - Matches one or more repetitions (`+`)
|
||||
- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`)
|
||||
- **`optional(p)`** - Matches zero or one occurrence (`?`)
|
||||
- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded)
|
||||
- **`repeat(p, n)`** - Matches exactly n repetitions
|
||||
|
||||
### Lookahead
|
||||
|
||||
- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`)
|
||||
- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`)
|
||||
|
||||
### Character Classes & Utilities
|
||||
|
||||
- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class
|
||||
- **`space()`** - Matches zero or more whitespace characters (space, tab, newline)
|
||||
- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed)
|
||||
- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found
|
||||
- **`rest()`** - Matches everything remaining (`.*`)
|
||||
|
||||
### JSON Parsers
|
||||
|
||||
- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null)
|
||||
- **`json_object()`** - JSON object parser
|
||||
- **`json_array()`** - JSON array parser
|
||||
- **`json_string()`** - JSON string parser
|
||||
- **`json_number()`** - JSON number parser
|
||||
- **`json_bool()`** - JSON boolean parser
|
||||
- **`json_null()`** - JSON null parser
|
||||
- **`json_string_content()`** - JSON string content without surrounding quotes
|
||||
- **`json_member(key, p)`** - JSON object member with specific key and value parser
|
||||
|
||||
### Grammar Building
|
||||
|
||||
- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars)
|
||||
- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference
|
||||
- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation)
|
||||
- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for grammar generation
|
||||
|
||||
### AST Control
|
||||
|
||||
- **`atomic(p)`** - Prevents AST node creation for partial parses
|
||||
- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags)
|
||||
|
||||
## GBNF Grammar Generation
|
||||
|
||||
The PEG parser also acts as a convenient DSL for generating GBNF grammars, with
|
||||
some exceptions.
|
||||
|
||||
```cpp
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(params.tools, [&](const json & fn) {
|
||||
builder.resolve_refs(fn.at("parameters"));
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
```
|
||||
|
||||
The notable exception is the `negate(p)` lookahead parser, which cannot be
|
||||
expressed as a context-free grammar and therefore does not produce a rule. Its usage
|
||||
should be limited and preferably hidden behind a `schema()` parser. In many
|
||||
cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice.
|
||||
|
||||
Another limitation is that the PEG parser requires an unambiguous grammar. In
|
||||
contrast, the `llama-grammar` implementation can support ambiguous grammars,
|
||||
though they are difficult to parse.
|
||||
|
||||
### Lazy Grammars
|
||||
|
||||
During lazy grammar generation, only rules reachable from a `trigger_rule(name, p)`
|
||||
are emitted in the grammar. All trigger rules are added as alternations in the
|
||||
root rule. It is still necessary to define trigger patterns, as the parser has
|
||||
no interaction with the grammar sampling.
|
||||
|
||||
### JSON Schema
|
||||
|
||||
The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar`
|
||||
implementation to generate the grammar instead of the underlying parser.
|
||||
|
||||
The `raw` option emits a grammar suitable for a raw string instead of a JSON
|
||||
string. In other words, it won't be wrapped in quotes or require escaping
|
||||
quotes. It should only be used when `type == "string"`.
|
||||
|
||||
The downside is that it can potentially lead to ambiguous grammars. For
|
||||
example, if a user provides the pattern `^.*$`, the following grammar may be
|
||||
generated:
|
||||
|
||||
```
|
||||
root ::= "<arg>" .* "</arg>"
|
||||
```
|
||||
|
||||
This creates an ambiguous grammar that cannot be parsed by the PEG parser. To
|
||||
help mitigate this, if `.*` is found in the pattern, the grammar from the
|
||||
underlying parser will be emitted instead.
|
||||
|
||||
## Common AST Shapes for Chat Parsing
|
||||
|
||||
Most model output can be placed in one of the following categories:
|
||||
|
||||
- Content only
|
||||
- Tool calling with arguments emitted as a single JSON object
|
||||
- Tool calling with arguments emitted as separate entities, either XML
|
||||
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
|
||||
|
||||
To provide broad coverage,
|
||||
[`common/chat-peg-parser.h`](../../common/chat-peg-parser.h) contains builders and
|
||||
mappers that help create parsers and visitors/extractors for these types. They
|
||||
require parsers to tag nodes to conform to an AST "shape". This normalization
|
||||
makes it easy to extract information and generalize parsing.
|
||||
|
||||
### Simple
|
||||
|
||||
The `common_chat_peg_builder` builds a `simple` parser that supports
|
||||
content-only models with optional reasoning.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for extracting `reasoning_content`
|
||||
- **`content(p)`** - Tag node for extracting `content`
|
||||
|
||||
```cpp
|
||||
build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
return p.sequence({
|
||||
p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>"),
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
Use `common_chat_peg_mapper` to extract the content. Note that this is already
|
||||
done for you in `common_chat_peg_parser` when
|
||||
`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`.
|
||||
|
||||
```cpp
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
```
|
||||
|
||||
### Native
|
||||
|
||||
The `common_chat_peg_native_builder` builds a `native` parser suitable for
|
||||
models that emit tool arguments as a direct JSON object.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
- **`content(p)`** - Tag node for `content`
|
||||
- **`tool(p)`** - Tag entirety of a single tool call
|
||||
- **`tool_open(p)`** - Tag start of a tool call
|
||||
- **`tool_close(p)`** - Tag end of a tool call
|
||||
- **`tool_id(p)`** - Tag the tool call ID (optional)
|
||||
- **`tool_name(p)`** - Tag the tool name
|
||||
- **`tool_args(p)`** - Tag the tool arguments
|
||||
|
||||
```cpp
|
||||
build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open(p.literal("{")),
|
||||
p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""),
|
||||
p.literal(","),
|
||||
p.json_member("arguments", p.tool_args(p.json())),
|
||||
p.tool_close(p.literal("}"))
|
||||
}));
|
||||
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.literal("<tool_call>"),
|
||||
get_weather_tool,
|
||||
p.literal("</tool_call>"),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Constructed
|
||||
|
||||
The `common_chat_peg_constructed_builder` builds a `constructed` parser
|
||||
suitable for models that emit tool arguments as separate entities, such as XML
|
||||
tags.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
- **`content(p)`** - Tag node for `content`
|
||||
- **`tool(p)`** - Tag entirety of a single tool call
|
||||
- **`tool_open(p)`** - Tag start of a tool call
|
||||
- **`tool_close(p)`** - Tag end of a tool call
|
||||
- **`tool_name(p)`** - Tag the tool name
|
||||
- **`tool_arg(p)`** - Tag a complete tool argument (name + value)
|
||||
- **`tool_arg_open(p)`** - Tag start of a tool argument
|
||||
- **`tool_arg_close(p)`** - Tag end of a tool argument
|
||||
- **`tool_arg_name(p)`** - Tag the argument name
|
||||
- **`tool_arg_string_value(p)`** - Tag string value for the argument
|
||||
- **`tool_arg_json_value(p)`** - Tag JSON value for the argument
|
||||
|
||||
```cpp
|
||||
build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
|
||||
auto location_arg = p.tool_arg(p.sequence({
|
||||
p.tool_arg_open("<parameter name=\"" + p.tool_arg_name(p.literal("location")) + "\">"),
|
||||
p.tool_arg_string_value(p.until("</parameter>")),
|
||||
p.tool_arg_close(p.literal("</parameter>"))
|
||||
}));
|
||||
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open("<function name=\"" + p.tool_name(p.literal("get_weather")) + "\">"),
|
||||
location_arg,
|
||||
p.tool_close(p.literal("</function>"))
|
||||
}));
|
||||
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.literal("<tool_call>"),
|
||||
get_weather_tool,
|
||||
p.literal("</tool_call>"),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
@@ -18,6 +18,7 @@ cd llama.cpp
|
||||
cmake -S . -B build
|
||||
cmake --build build
|
||||
cmake --install build --prefix inst
|
||||
```
|
||||
|
||||
### Build simple-cmake-pkg
|
||||
|
||||
|
||||
@@ -175,11 +175,6 @@ option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requi
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
|
||||
|
||||
|
||||
if (MINGW)
|
||||
set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version")
|
||||
endif()
|
||||
|
||||
# ggml core
|
||||
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
||||
option(GGML_CPU "ggml: enable CPU backend" ON)
|
||||
|
||||
@@ -204,6 +204,10 @@
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
# define _WIN32_WINNT 0x0A00
|
||||
#endif
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@@ -2279,7 +2283,7 @@ extern "C" {
|
||||
float stop,
|
||||
float step);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
#define GGML_KQ_MASK_PAD 1
|
||||
|
||||
// q: [n_embd_k, n_batch, n_head, ne3 ]
|
||||
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]
|
||||
|
||||
@@ -127,10 +127,6 @@ if (NOT MSVC)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (MINGW)
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
#
|
||||
# POSIX conformance
|
||||
#
|
||||
|
||||
@@ -505,7 +505,6 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
@@ -645,7 +644,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
|
||||
@@ -6383,7 +6383,7 @@ static void ggml_compute_forward_im2col_3d_f16(
|
||||
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
||||
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
||||
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
||||
} else {
|
||||
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
|
||||
|
||||
@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
|
||||
const int nbatch_fa) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
|
||||
template <int DV, int ncols1, int ncols2>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
@@ -790,8 +791,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
@@ -915,7 +914,7 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
@@ -970,6 +969,9 @@ void launch_fattn(
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
// TODO other tensor dimensions after removal of WMMA kernel:
|
||||
const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
|
||||
|
||||
GGML_ASSERT(block_dim.x % warp_size == 0);
|
||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||
(const char *) Q->data,
|
||||
@@ -980,7 +982,7 @@ void launch_fattn(
|
||||
KV_max.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
||||
@@ -995,7 +997,7 @@ void launch_fattn(
|
||||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -501,6 +501,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half * const __restrict__ mask,
|
||||
const uint3 ne01,
|
||||
const float logit_softcap,
|
||||
const float slope,
|
||||
T_KQ * const KQ,
|
||||
@@ -512,7 +513,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
float * const KQ_sum,
|
||||
T_acc * const VKQ,
|
||||
const int k_VKQ_0,
|
||||
const int k_VKQ_max) {
|
||||
const int k_VKQ_max,
|
||||
const int col_Q_0) {
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
@@ -556,7 +558,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
|
||||
// Apply logit softcap + mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < cpw; ++jc0) {
|
||||
const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
|
||||
const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
|
||||
@@ -736,7 +738,7 @@ static __global__ void flash_attn_tile(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
@@ -781,11 +783,11 @@ static __global__ void flash_attn_tile(
|
||||
const int sequence = blockIdx.z / (ne02/ncols2);
|
||||
const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0);
|
||||
const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
|
||||
|
||||
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
|
||||
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
|
||||
|
||||
const int stride_K2 = nb11 / sizeof(half2);
|
||||
const int stride_V2 = nb21 / sizeof(half2);
|
||||
@@ -842,11 +844,9 @@ static __global__ void flash_attn_tile(
|
||||
for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
|
||||
if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ncols1 == 1 || col_Q_0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>
|
||||
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
|
||||
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>
|
||||
(tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
|
||||
+ i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
@@ -881,23 +881,23 @@ static __global__ void flash_attn_tile(
|
||||
while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
|
||||
constexpr bool oob_check = false;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
k_VKQ_0 += gridDim.y*nbatch_fa;
|
||||
}
|
||||
if (k_VKQ_0 < k_VKQ_max) {
|
||||
constexpr bool oob_check = true;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
}
|
||||
} else {
|
||||
// Branch without out-of-bounds checks.
|
||||
for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
|
||||
constexpr bool oob_check = false;
|
||||
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
|
||||
(Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
|
||||
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
|
||||
stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1010,13 +1010,13 @@ static __global__ void flash_attn_tile(
|
||||
const int j = jc / ncols2;
|
||||
const int c = jc % ncols2;
|
||||
|
||||
if (ncols1 > 1 && col_Q_0 + j >= ne01) {
|
||||
if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
|
||||
const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
|
||||
|
||||
@@ -33,7 +33,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
@@ -150,7 +150,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 1 && ic0 + j >= ne01) {
|
||||
if (ncols > 1 && ic0 + j >= int(ne01.z)) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
@@ -201,7 +201,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
|
||||
float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
|
||||
if (ncols == 1 || ic0 + j < ne01) {
|
||||
if (ncols == 1 || ic0 + j < int(ne01.z)) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
@@ -222,7 +222,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
if (ncols == 1 || ic0 + j < ne01) {
|
||||
if (ncols == 1 || ic0 + j < int(ne01.z)) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
@@ -266,7 +266,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
|
||||
sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
|
||||
}
|
||||
|
||||
@@ -412,7 +412,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= ne01) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -479,7 +479,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_sum[j_VKQ];
|
||||
}
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
||||
dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,8 +489,8 @@ static __global__ void flash_attn_ext_vec(
|
||||
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
|
||||
dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
|
||||
@@ -38,14 +38,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -149,7 +149,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (i0 + warp_size > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,7 +218,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||
KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
|
||||
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
||||
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
|
||||
}
|
||||
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
|
||||
@@ -270,7 +271,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
|
||||
}
|
||||
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
@@ -431,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j_VKQ = j0 + threadIdx.y;
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
if (ic0 + j_VKQ >= int(ne01.z)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -442,7 +443,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||
@@ -481,7 +482,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
#if defined(GGML_USE_MUSA)
|
||||
#define GGML_USE_WMMA_FATTN
|
||||
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
#endif // defined(GGML_USE_MUSA)
|
||||
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
|
||||
|
||||
@@ -12,13 +12,13 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
@@ -41,7 +41,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
@@ -275,8 +275,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
// If Turing tensor cores are available, use them:
|
||||
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
@@ -297,7 +297,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
|
||||
int gqa_ratio_eff = 1;
|
||||
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
|
||||
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
|
||||
gqa_ratio_eff *= 2;
|
||||
}
|
||||
if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
if (Q->ne[1] * gqa_ratio_eff <= 16) {
|
||||
return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
|
||||
}
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
|
||||
@@ -68,10 +68,31 @@ static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
|
||||
|
||||
namespace ggml_cuda_mma {
|
||||
|
||||
// Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel;
|
||||
// effectively, the warp is split into subgroups of threads that each perform a single mma instruction.
|
||||
// In those cases the data can be split in different ways across the warp.
|
||||
enum data_layout {
|
||||
// By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
|
||||
// For the A/C matrices this means I major == row major, J major == column major.
|
||||
// For the B matrix this means I major == column major, J major == row major.
|
||||
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
||||
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
|
||||
DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
|
||||
DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
|
||||
};
|
||||
// Implemented mma combinations are:
|
||||
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
||||
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
||||
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
||||
|
||||
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
||||
struct tile {};
|
||||
|
||||
template <int I_, int J_, typename T>
|
||||
struct tile {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 64;
|
||||
@@ -131,9 +152,9 @@ namespace ggml_cuda_mma {
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 32 && J == 8) {
|
||||
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
|
||||
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
|
||||
#else
|
||||
return (l & 2) | (threadIdx.x & ~2);
|
||||
return (l & 2) + (threadIdx.x & ~2);
|
||||
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -143,7 +164,7 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x & 2) | (l & (4 + 1));
|
||||
return (threadIdx.x & 2) + (l & (4 + 1));
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -196,9 +217,9 @@ namespace ggml_cuda_mma {
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l / 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l / 2) * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
|
||||
return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
|
||||
} else {
|
||||
@@ -211,11 +232,11 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
return (l * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((threadIdx.x % 4) * 2) | (l % 2);
|
||||
return ((threadIdx.x % 4) * 2) + (l % 2);
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
|
||||
return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
|
||||
} else {
|
||||
@@ -227,26 +248,24 @@ namespace ggml_cuda_mma {
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
if (I == 32 && J == 4) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
if constexpr (I == 32 && J == 4) {
|
||||
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
|
||||
return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
||||
#else
|
||||
return threadIdx.x;
|
||||
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
@@ -257,7 +276,7 @@ namespace ggml_cuda_mma {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr ((I == 8 || I == 32) && J == 8) {
|
||||
if constexpr (I == 32 && J == 4) {
|
||||
return l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -307,11 +326,11 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return (l * 8) | (threadIdx.x / 4);
|
||||
return (l * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l % 2) * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -320,13 +339,13 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
return (l * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l / 2) * 4) | (threadIdx.x % 4);
|
||||
return ((l / 2) * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return ((l & 2) * 2) | (threadIdx.x % 4);
|
||||
return ((l & 2) * 2) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -336,14 +355,15 @@ namespace ggml_cuda_mma {
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, nv_bfloat162> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr int ne = I * J / 32;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
@@ -367,9 +387,6 @@ namespace ggml_cuda_mma {
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 16 && J == 4) return true;
|
||||
@@ -381,9 +398,9 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return (l * 8) | (threadIdx.x / 4);
|
||||
return (l * 8) + (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
return ((l % 2) * 8) + (threadIdx.x / 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -392,11 +409,11 @@ namespace ggml_cuda_mma {
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
return (l * 4) + (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return ((l / 2) * 4) | (threadIdx.x % 4);
|
||||
return ((l / 2) * 4) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
@@ -405,6 +422,73 @@ namespace ggml_cuda_mma {
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||
static constexpr int ne = I * J / (WARP_SIZE/4);
|
||||
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 4) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
|
||||
static constexpr int ne = I * J / (WARP_SIZE/4);
|
||||
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 4) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return ((l / 2) * 4) + (threadIdx.x % 4);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return ((threadIdx.x / 16) * 2) + (l % 2);
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
@@ -422,9 +506,26 @@ namespace ggml_cuda_mma {
|
||||
|
||||
return ret;
|
||||
}
|
||||
#else // Volta
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
|
||||
ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
|
||||
ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
|
||||
|
||||
template <int I, int J, typename T>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
// On Volta FP16 and FP32 tiles have a different memory layout,
|
||||
// for the conversion threads with an offset of 2 need to exchange half their values:
|
||||
ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
|
||||
0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
|
||||
template <int I, int J, typename T, data_layout dl>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
||||
#pragma unroll
|
||||
@@ -511,18 +612,6 @@ namespace ggml_cuda_mma {
|
||||
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
NO_DEVICE_CODE;
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if 1
|
||||
// TODO: more generic handling
|
||||
@@ -533,9 +622,31 @@ namespace ggml_cuda_mma {
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // 1
|
||||
#else
|
||||
tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
|
||||
load_ldmatrix(t16[0], xs0 + 0*stride, stride);
|
||||
load_ldmatrix(t16[1], xs0 + 16*stride, stride);
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < t.ne; l0 += 2) {
|
||||
ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
|
||||
#else
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
@@ -860,14 +971,14 @@ namespace ggml_cuda_mma {
|
||||
template <typename T1, typename T2, int J, int K>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
|
||||
tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
|
||||
tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
|
||||
tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
|
||||
const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
@@ -880,20 +991,30 @@ namespace ggml_cuda_mma {
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
|
||||
#else
|
||||
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
|
||||
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
|
||||
@@ -37,23 +37,19 @@ static __global__ void mul_mat_f(
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
@@ -248,6 +244,9 @@ static __global__ void mul_mat_f(
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif //VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
@@ -278,27 +277,24 @@ static __global__ void mul_mat_f_ids(
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<tile_B_I, 8, T> tile_B;
|
||||
typedef tile<16, tile_C_J, float> tile_C;
|
||||
|
||||
constexpr bool a_supported = tile_A::supported();
|
||||
constexpr bool b_supported = tile_B::supported();
|
||||
constexpr bool c_supported = tile_C::supported();
|
||||
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||
#else
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
constexpr bool supported = I_16_supported || I_32_supported;
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
|
||||
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
|
||||
#else
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
if constexpr (!supported) {
|
||||
if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
@@ -517,6 +513,9 @@ static __global__ void mul_mat_f_ids(
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
|
||||
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
|
||||
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
|
||||
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
|
||||
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
|
||||
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
|
||||
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
|
||||
|
||||
// a collection of pipelines
|
||||
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
|
||||
|
||||
@@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
|
||||
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
|
||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
|
||||
|
||||
struct ggml_metal_pipeline_with_params {
|
||||
ggml_metal_pipeline_t pipeline;
|
||||
|
||||
int nsg;
|
||||
|
||||
int nr0;
|
||||
int nr1;
|
||||
|
||||
size_t smem;
|
||||
};
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
|
||||
|
||||
//
|
||||
// MTLCommandBuffer wrapper
|
||||
//
|
||||
@@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
|
||||
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
|
||||
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
|
||||
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
|
||||
|
||||
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
|
||||
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
|
||||
@@ -100,66 +99,67 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t nqptg,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
@@ -169,7 +169,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
bool has_kvpad,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
@@ -180,7 +180,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t dv,
|
||||
|
||||
@@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
|
||||
|
||||
struct ggml_metal_pipeline {
|
||||
id<MTLComputePipelineState> obj;
|
||||
|
||||
// suggested dispatch sizes
|
||||
int nsg;
|
||||
|
||||
int nr0;
|
||||
int nr1;
|
||||
|
||||
size_t smem;
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
|
||||
@@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
|
||||
|
||||
*res = (struct ggml_metal_pipeline) {
|
||||
/*.obj =*/ nil,
|
||||
/*.nsg =*/ 0,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
};
|
||||
|
||||
return res;
|
||||
@@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
|
||||
free(pipeline);
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) {
|
||||
pipeline->nsg = nsg;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->nsg;
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) {
|
||||
pipeline->nr0 = nr0;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->nr0;
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) {
|
||||
pipeline->nr1 = nr1;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->nr1;
|
||||
}
|
||||
|
||||
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) {
|
||||
pipeline->smem = smem;
|
||||
}
|
||||
|
||||
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->smem;
|
||||
}
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) {
|
||||
return pipeline->obj.maxTotalThreadsPerThreadgroup;
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
|
||||
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
|
||||
}
|
||||
|
||||
struct ggml_metal_library {
|
||||
@@ -146,6 +102,8 @@ struct ggml_metal_library {
|
||||
id<MTLDevice> device;
|
||||
|
||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||
|
||||
NSLock * lock;
|
||||
};
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
@@ -296,9 +254,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -365,6 +324,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -380,26 +340,47 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
|
||||
ggml_metal_pipelines_free(lib->pipelines);
|
||||
|
||||
[lib->lock release];
|
||||
|
||||
free(lib);
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
return ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
[lib->lock lock];
|
||||
|
||||
struct ggml_metal_pipeline_with_params res = {
|
||||
/*.pipeline =*/ nil,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
};
|
||||
|
||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
// note: the pipelines are cached in the library per device, so they are shared across all metal contexts
|
||||
ggml_critical_section_start();
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
struct ggml_metal_pipeline_with_params res = {
|
||||
/*.pipeline =*/ nil,
|
||||
/*.nr0 =*/ 0,
|
||||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock lock];
|
||||
|
||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
if (res.pipeline) {
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_pipeline_init();
|
||||
|
||||
@autoreleasepool {
|
||||
NSError * error = nil;
|
||||
|
||||
@@ -414,36 +395,53 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
}
|
||||
if (!mtl_function) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
|
||||
}
|
||||
|
||||
return nil;
|
||||
return res;
|
||||
}
|
||||
|
||||
res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
|
||||
id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
|
||||
|
||||
[mtl_function release];
|
||||
|
||||
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
|
||||
(int) res->obj.maxTotalThreadsPerThreadgroup,
|
||||
(int) res->obj.threadExecutionWidth);
|
||||
if (!obj) {
|
||||
[lib->lock unlock];
|
||||
|
||||
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
|
||||
ggml_critical_section_end();
|
||||
GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
|
||||
(void *) obj,
|
||||
(int) obj.maxTotalThreadsPerThreadgroup,
|
||||
(int) obj.threadExecutionWidth);
|
||||
|
||||
if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
|
||||
[obj release];
|
||||
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
|
||||
|
||||
return nil;
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipelines_add(lib->pipelines, name, res);
|
||||
res.pipeline = ggml_metal_pipeline_init();
|
||||
res.pipeline->obj = obj;
|
||||
|
||||
ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
|
||||
}
|
||||
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -485,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
|
||||
[encoder->obj popDebugGroup];
|
||||
}
|
||||
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) {
|
||||
[encoder->obj setComputePipelineState:pipeline->obj];
|
||||
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
|
||||
[encoder->obj setComputePipelineState:pipeline.pipeline->obj];
|
||||
}
|
||||
|
||||
void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
|
||||
@@ -611,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
|
||||
dev->props.has_tensor = false;
|
||||
} else {
|
||||
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl) {
|
||||
struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl.pipeline) {
|
||||
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
|
||||
dev->props.has_tensor = false;
|
||||
}
|
||||
@@ -661,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
|
||||
dev->props.has_bfloat = false;
|
||||
} else {
|
||||
ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl) {
|
||||
struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
|
||||
if (!ppl.pipeline) {
|
||||
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
|
||||
dev->props.has_bfloat = false;
|
||||
}
|
||||
@@ -820,6 +818,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
@@ -852,6 +852,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -869,6 +870,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_TRI:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_MEAN:
|
||||
|
||||
@@ -182,6 +182,10 @@ typedef struct {
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float val;
|
||||
} ggml_metal_kargs_fill;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
@@ -831,6 +835,25 @@ typedef struct {
|
||||
float slope;
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_tri;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
||||
@@ -286,6 +286,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_scale(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_fuse = ggml_metal_op_fill(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
||||
@@ -414,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tri(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
||||
@@ -524,7 +532,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
/*.dim =*/ dim,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -550,7 +558,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||
|
||||
ggml_metal_kargs_repeat args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -616,7 +624,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
@@ -679,7 +687,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
/*.o1 =*/ { 0 },
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -721,7 +729,42 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float val = ggml_get_op_params_f32(op, 0);
|
||||
|
||||
ggml_metal_kargs_fill args = {
|
||||
/*.val =*/ val
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -760,7 +803,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -789,7 +832,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
||||
@@ -817,7 +860,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
|
||||
|
||||
const int32_t swp = ggml_get_op_params_i32(op, 1);
|
||||
const float alpha = ggml_get_op_params_f32(op, 2);
|
||||
@@ -870,7 +913,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||
/*.np =*/ n,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -925,7 +968,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -936,7 +979,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -963,7 +1006,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||
auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||
|
||||
int nth = 1;
|
||||
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
||||
@@ -1060,7 +1103,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||
auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||
|
||||
ggml_metal_kargs_cumsum_add args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1106,7 +1149,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
|
||||
ggml_metal_kargs_get_rows args = {
|
||||
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
||||
@@ -1151,7 +1194,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||
|
||||
const int32_t nk0 = ne0/ggml_blck_size(op->type);
|
||||
|
||||
@@ -1252,7 +1295,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
/*.n_head_log2 =*/ n_head_log2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -1266,7 +1309,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
}
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
@@ -1322,7 +1365,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
@@ -1409,11 +1452,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb0 =*/ nb0,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
||||
|
||||
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -1426,7 +1469,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
||||
|
||||
@@ -1449,7 +1492,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
const int64_t C = op->ne[0];
|
||||
const int64_t H = op->src[0]->ne[1];
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
|
||||
|
||||
int ida = 0;
|
||||
|
||||
@@ -1485,7 +1528,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
|
||||
|
||||
@@ -1592,7 +1635,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
/* .np = */ np
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
|
||||
const int ntg = (np + nth - 1) / nth;
|
||||
@@ -1701,7 +1744,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ABORT("unsupported ne11");
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
|
||||
|
||||
ggml_metal_kargs_mul_mv_ext args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1748,7 +1791,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1773,18 +1816,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
|
||||
} else {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
const int nr0 = pipeline.nr0;
|
||||
const int nr1 = pipeline.nr1;
|
||||
const int nsg = pipeline.nsg;
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_kargs_mul_mv args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1915,9 +1958,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
nb21,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -1938,7 +1981,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1967,20 +2010,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
||||
}
|
||||
} else {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
const int nr0 = pipeline.nr0;
|
||||
const int nr1 = pipeline.nr1;
|
||||
const int nsg = pipeline.nsg;
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_kargs_mul_mv_id args = {
|
||||
/*.nei0 =*/ ne20,
|
||||
@@ -2064,7 +2107,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb21 =*/ nb21,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -2308,7 +2351,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2339,7 +2382,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb33 =*/ nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2424,7 +2467,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -2476,7 +2519,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2578,7 +2621,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
|
||||
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -2630,7 +2673,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
nrows,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
||||
auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
@@ -2762,7 +2805,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
// the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
|
||||
bid_src1.offs = 0;
|
||||
|
||||
ggml_metal_pipeline_t pipeline = nullptr;
|
||||
struct ggml_metal_pipeline_with_params pipeline;
|
||||
|
||||
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
@@ -2835,7 +2878,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
|
||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
@@ -2844,7 +2887,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
@@ -2887,7 +2930,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
//while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
@@ -2897,7 +2940,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
//nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
//nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3022,7 +3065,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
@@ -3033,7 +3076,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, args.ne00_t);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3127,7 +3170,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
/* src2 =*/ op->src[2] != nullptr,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3199,7 +3242,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
/*.KHW =*/ KH * KW,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
||||
|
||||
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -3270,7 +3313,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.d1 =*/ d1,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
|
||||
|
||||
int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
|
||||
nth = std::min(nth, 256);
|
||||
@@ -3325,7 +3368,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb1 =*/ nb1,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3377,7 +3420,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3433,7 +3476,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
/*.sf3 =*/ sf3
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
|
||||
@@ -3477,7 +3520,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb3 =*/ nb3
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
|
||||
@@ -3523,7 +3566,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
/*.p1 =*/ ((const int32_t *)(op->op_params))[1]
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
|
||||
@@ -3560,7 +3603,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3591,7 +3634,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
/*.max_period =*/ max_period,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
|
||||
|
||||
const int nth = std::max(1, std::min(1024, dim/2));
|
||||
|
||||
@@ -3621,7 +3664,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb01 = */ nb01,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
@@ -3630,7 +3673,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -3657,7 +3700,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int nth = 1;
|
||||
@@ -3706,7 +3749,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
|
||||
int len = nth;
|
||||
|
||||
@@ -3764,7 +3807,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int nth = 1;
|
||||
@@ -3818,7 +3861,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
||||
auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
|
||||
|
||||
int len = args.top_k;
|
||||
|
||||
@@ -3881,7 +3924,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
/*.slope =*/ slope
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
@@ -3899,6 +3942,57 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Encodes the TRI op for graph node `idx` on the context's Metal encoder.
//
// The kernel arguments carry the extents/strides of src[0] (ne00..ne03,
// nb00..nb03) and of the destination (ne0..ne3, nb0..nb3). One threadgroup is
// dispatched per row (ne01 x ne02 x ne03 groups). Returns 1.
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_encoder_t enc = ctx->enc;
    ggml_metal_library_t lib = ctx->lib;

    // locals for the source tensor (ne0*/nb0*) and for the destination (ne*/nb*)
    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    ggml_metal_kargs_tri args = {
        /*.ne00 =*/ ne00,
        /*.ne01 =*/ ne01,
        /*.ne02 =*/ ne02,
        /*.ne03 =*/ ne03,
        /*.nb00 =*/ nb00,
        /*.nb01 =*/ nb01,
        /*.nb02 =*/ nb02,
        /*.nb03 =*/ nb03,
        /*.ne0  =*/ ne0,
        /*.ne1  =*/ ne1,
        /*.ne2  =*/ ne2,
        /*.ne3  =*/ ne3,
        /*.nb0  =*/ nb0,
        /*.nb1  =*/ nb1,
        /*.nb2  =*/ nb2,
        /*.nb3  =*/ nb3,
    };

    auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);

    // upper bound on the threadgroup size imposed by the pipeline
    const int nth_max = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);

    // start at the SIMD width (32) and keep doubling until the row is covered
    // or the pipeline limit is reached
    int nth = 32;
    for (; nth < ne00 && nth < nth_max; nth *= 2) {
    }

    // never exceed the pipeline limit nor the row length
    nth = std::min(std::min(nth, nth_max), ne00);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);

    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);

    return 1;
}
|
||||
|
||||
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -3910,7 +4004,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_adamw args = {
|
||||
@@ -3946,7 +4040,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_sgd args = {
|
||||
|
||||
@@ -47,6 +47,7 @@ int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
@@ -83,6 +84,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
|
||||
@@ -1249,6 +1249,22 @@ kernel void kernel_scale_f32_4(
|
||||
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
||||
}
|
||||
|
||||
// Fill kernel, scalar float variant: each thread stores args.val into the dst
// element at its own grid position. src0 is part of the signature but is
// never read.
kernel void kernel_fill_f32(
        constant ggml_metal_kargs_fill & args,
        device  const float * src0,
        device        float * dst,
        uint tpig[[thread_position_in_grid]]) {
    device float & out = dst[tpig];
    out = args.val;
}
|
||||
|
||||
// Fill kernel, float4 variant: each thread stores args.val into the float4 at
// its own grid position. src0 is part of the signature but is never read.
kernel void kernel_fill_f32_4(
        constant ggml_metal_kargs_fill & args,
        device  const float4 * src0,
        device        float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    device float4 & out = dst[tpig];
    out = args.val;
}
|
||||
|
||||
kernel void kernel_clamp_f32(
|
||||
constant ggml_metal_kargs_clamp & args,
|
||||
device const float * src0,
|
||||
@@ -1595,6 +1611,36 @@ kernel void kernel_exp_f32_4(
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
// Softplus, scalar float variant: dst = log(1 + exp(x)).
// For x > 20 the input is passed through unchanged instead of evaluating the
// formula.
kernel void kernel_softplus_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    if (x > 20.0f) {
        dst[tpig] = x;
    } else {
        dst[tpig] = log(1.0f + exp(x));
    }
}
|
||||
|
||||
// Softplus, float4 variant: per component, dst = log(1 + exp(x)), except that
// components with x > 20 are passed through unchanged.
kernel void kernel_softplus_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x        = src0[tpig];
    const float4 softplus = log(1.0f + exp(x));

    // component-wise pick: x where x > 20, the softplus value elsewhere
    dst[tpig] = select(softplus, x, x > 20.0f);
}
|
||||
|
||||
// expm1, scalar float variant: dst = exp(x) - 1, evaluated literally as an
// exponential followed by a subtraction.
kernel void kernel_expm1_f32(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float e = exp(src0[tpig]);
    dst[tpig] = e - 1.0f;
}
|
||||
|
||||
// expm1, float4 variant: per component, dst = exp(x) - 1, evaluated literally
// as an exponential followed by a subtraction.
kernel void kernel_expm1_f32_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 e = exp(src0[tpig]);
    dst[tpig] = e - 1.0f;
}
|
||||
|
||||
kernel void kernel_reglu_f32(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
@@ -1943,6 +1989,75 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||
|
||||
|
||||
// Keep-predicate used by kernel_tri: returns true when the element at
// position i within row r must be kept. One specialization exists per
// triangle type; the numeric template argument is the value noted in the
// comment next to it.
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);

// strictly below the diagonal
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
    return r > i;
}

// below the diagonal, diagonal included
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
    return r >= i;
}

// strictly above the diagonal
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
    return r < i;
}

// above the diagonal, diagonal included
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
    return r <= i;
}
|
||||
|
||||
// Triangular masking kernel: copies src0 into dst, writing zero for every
// element whose (column, row) position is rejected by
// _ggml_vec_tri_cmp<ttype>.
//
// tgpig.x selects the row, tgpig.y and tgpig.z the two outer dimensions; the
// threads of a threadgroup stride over the ne00 elements of that row.
//
// NOTE(review): dst is declared `const` in the signature but is written
// through the cast below; the signature is left unchanged here.
template<typename T, int ttype>
kernel void kernel_tri(
        constant ggml_metal_kargs_tri & args,
        device   const char * src0,
        device   const char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    // threadgroups outside the tensor extents have nothing to do
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       T * dst_row = (device       T *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // Each thread is a single element of the row if ne00 < max threads per
    // threadgroup, so this will loop once for each index that this thread is
    // responsible for
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        // Select instead of multiplying by the mask: 0 * Inf and 0 * NaN are
        // NaN, so a multiply would leak non-finite source values into the
        // region that must be exactly zero.
        dst_row[i0] = _ggml_vec_tri_cmp<ttype>(i0, i1) ? src_row[i0] : static_cast<T>(0.0f);
    }
}

// The signature does not depend on T or ttype, so one function type is reused
// for every instantiation below.
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;

// Host-visible variants, named kernel_tri_<element type>_<ttype>.
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
|
||||
@@ -31,6 +31,14 @@ except ImportError:
|
||||
else:
|
||||
_mistral_common_installed = True
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
get_one_valid_tokenizer_file,
|
||||
)
|
||||
except ImportError:
|
||||
# We still want the conversion to work with older mistral-common versions.
|
||||
get_one_valid_tokenizer_file = None
|
||||
|
||||
|
||||
import gguf
|
||||
|
||||
@@ -673,24 +681,30 @@ class MistralVocab(Vocab):
|
||||
|
||||
# Find the tokenizer files
|
||||
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
|
||||
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
|
||||
|
||||
if len(valid_tokenizer_files) == 0:
|
||||
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
|
||||
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
|
||||
if len(valid_tokenizer_files) > 1:
|
||||
if "tekken.json" in valid_tokenizer_files:
|
||||
tokenizer_file = "tekken.json"
|
||||
else:
|
||||
tokenizer_file = sorted(valid_tokenizer_files)[-1]
|
||||
logger.warning(
|
||||
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
|
||||
)
|
||||
if get_one_valid_tokenizer_file is not None:
|
||||
tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
|
||||
else:
|
||||
tokenizer_file = valid_tokenizer_files[0]
|
||||
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
|
||||
|
||||
if len(valid_tokenizer_files) == 0:
|
||||
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
|
||||
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
|
||||
if len(valid_tokenizer_files) > 1:
|
||||
if "tekken.json" in valid_tokenizer_files:
|
||||
tokenizer_file = "tekken.json"
|
||||
else:
|
||||
tokenizer_file = sorted(valid_tokenizer_files)[-1]
|
||||
logger.warning(
|
||||
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
|
||||
)
|
||||
else:
|
||||
tokenizer_file = valid_tokenizer_files[0]
|
||||
|
||||
tokenizer_file_path = base_path / tokenizer_file
|
||||
|
||||
self.tokenizer = MistralTokenizer.from_file(
|
||||
base_path / tokenizer_file
|
||||
tokenizer_file_path
|
||||
).instruct_tokenizer.tokenizer
|
||||
self.tokenizer_type = (
|
||||
MistralTokenizerType.tekken
|
||||
@@ -698,7 +712,7 @@ class MistralVocab(Vocab):
|
||||
else MistralTokenizerType.spm
|
||||
)
|
||||
self.vocab_size = self.tokenizer.n_words
|
||||
self.fname_tokenizer = base_path / tokenizer_file
|
||||
self.fname_tokenizer = tokenizer_file_path
|
||||
self._name = (
|
||||
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
|
||||
)
|
||||
|
||||
@@ -726,21 +726,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// sanity checks for models that have attention layers
|
||||
if (qs.n_attention_wv != 0 && !is_clip_model)
|
||||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
int32_t n_layer_all = model.hparams.n_layer;
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
// now n_layer_attn is the number of attention layers in the encoder
|
||||
// now n_layer_all is the number of attention layers in the encoder
|
||||
// for each decoder block, there are 2 attention layers
|
||||
n_layer_attn += 2 * model.hparams.dec_n_layer;
|
||||
n_layer_all += 2 * model.hparams.dec_n_layer;
|
||||
}
|
||||
|
||||
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
|
||||
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
|
||||
LLAMA_LOG_INFO("%s: n_layer_all = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_all, n_layer_recr, pruned_attention_w);
|
||||
|
||||
GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
GGML_ASSERT((qs.n_attention_wv == n_layer_all - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
|
||||
1
tests/.gitignore
vendored
1
tests/.gitignore
vendored
@@ -3,3 +3,4 @@
|
||||
*.o
|
||||
ggml-common.h
|
||||
**/*.swp
|
||||
!peg-parser
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
llama_add_compile_flags()
|
||||
|
||||
function(llama_build source)
|
||||
set(TEST_SOURCES ${source} ${ARGN})
|
||||
|
||||
if (DEFINED LLAMA_TEST_NAME)
|
||||
set(TEST_TARGET ${LLAMA_TEST_NAME})
|
||||
else()
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source})
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCES})
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
endfunction()
|
||||
@@ -83,6 +85,8 @@ function(llama_build_and_test source)
|
||||
set(multiValueArgs ARGS)
|
||||
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp)
|
||||
|
||||
if (NOT DEFINED LLAMA_TEST_LABEL)
|
||||
set(LLAMA_TEST_LABEL "main")
|
||||
endif()
|
||||
@@ -95,7 +99,7 @@ function(llama_build_and_test source)
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source} get-model.cpp)
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCES})
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
|
||||
@@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
endif()
|
||||
|
||||
llama_build_and_test(test-chat-parser.cpp)
|
||||
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(
|
||||
test-peg-parser.cpp
|
||||
peg-parser/simple-tokenize.cpp
|
||||
peg-parser/test-basic.cpp
|
||||
peg-parser/test-gbnf-generation.cpp
|
||||
peg-parser/test-json-parser.cpp
|
||||
peg-parser/test-json-serialization.cpp
|
||||
peg-parser/test-unicode.cpp
|
||||
peg-parser/testing.h
|
||||
peg-parser/tests.h
|
||||
)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
|
||||
37
tests/peg-parser/simple-tokenize.cpp
Normal file
37
tests/peg-parser/simple-tokenize.cpp
Normal file
@@ -0,0 +1,37 @@
|
||||
#include "simple-tokenize.h"
|
||||
|
||||
// Splits `input` into tokens.
//
// Each delimiter character (space, newline, tab and the punctuation
// { } , [ " ] . < > = /) closes the token collected so far and becomes the
// first character of the next one. Nothing is dropped: concatenating the
// returned tokens reproduces `input`, and no returned token is empty.
std::vector<std::string> simple_tokenize(const std::string & input) {
    static const std::string delimiters = " \n\t{},[\"].<>=/";

    std::vector<std::string> result;
    std::string current;

    for (const char c : input) {
        const bool starts_new_token = delimiters.find(c) != std::string::npos;

        if (starts_new_token && !current.empty()) {
            // hand the finished token to the result without copying it;
            // swapping with the fresh empty element also resets `current`
            result.emplace_back();
            result.back().swap(current);
        }

        current += c;
    }

    // flush whatever is left after the last delimiter
    if (!current.empty()) {
        result.push_back(current);
    }

    return result;
}
|
||||
6
tests/peg-parser/simple-tokenize.h
Normal file
6
tests/peg-parser/simple-tokenize.h
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Splits the given string into a list of string tokens
// (implemented in simple-tokenize.cpp).
std::vector<std::string> simple_tokenize(const std::string &);
|
||||
454
tests/peg-parser/test-basic.cpp
Normal file
454
tests/peg-parser/test-basic.cpp
Normal file
@@ -0,0 +1,454 @@
|
||||
#include "tests.h"
|
||||
|
||||
void test_basic(testing & t) {
|
||||
t.test("chars", [](testing & t) {
|
||||
// Test common escape sequences - newline
|
||||
t.test("escape_sequence_newline", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("\n");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_newline", true, result.success());
|
||||
});
|
||||
|
||||
// Test common escape sequences - tab
|
||||
t.test("escape_sequence_tab", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("\t");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_tab", true, result.success());
|
||||
});
|
||||
|
||||
// Test common escape sequences - backslash
|
||||
t.test("escape_sequence_backslash", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("\\");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_backslash", true, result.success());
|
||||
});
|
||||
|
||||
// Test common escape sequences - space (should fail)
|
||||
t.test("escape_sequence_space_fail", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context(" ");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escape_sequence_space_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Test escaped dash - 'a' should succeed
|
||||
t.test("escaped_dash_a", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("a");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_a", true, result.success());
|
||||
});
|
||||
|
||||
// Test escaped dash - '-' should succeed (literal dash)
|
||||
t.test("escaped_dash_literal", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("-");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_literal", true, result.success());
|
||||
});
|
||||
|
||||
// Test escaped dash - 'z' should succeed
|
||||
t.test("escaped_dash_z", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("z");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_z", true, result.success());
|
||||
});
|
||||
|
||||
// Test escaped dash - 'b' should NOT match (since \- is literal dash, not range)
|
||||
t.test("escaped_dash_b_fail", [](testing &t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("b");
|
||||
result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("escaped_dash_b_fail", true, result.fail());
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
t.test("optional", [](testing & t) {
|
||||
// Full match with optional part present
|
||||
t.test("optional_present", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto ctx = common_peg_parse_context("hello world");
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("optional_present", true, result.success());
|
||||
t.assert_equal("optional_present_end", 11u, result.end);
|
||||
});
|
||||
|
||||
// Full match with optional part absent
|
||||
t.test("optional_absent", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto ctx = common_peg_parse_context("hello", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("optional_absent", true, result.success());
|
||||
t.assert_equal("optional_absent_end", 5u, result.end);
|
||||
});
|
||||
|
||||
// Partial match - waiting for more input to determine if optional matches
|
||||
t.test("partial_match_need_more", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto ctx = common_peg_parse_context("hello ", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("partial_match_need_more", true, result.need_more_input());
|
||||
});
|
||||
});
|
||||
|
||||
t.test("partial parsing", [](testing & t) {
|
||||
// Literals - Basic Success
|
||||
t.test("literal_success", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("hello");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("literal_success", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes - Basic Lowercase Success
|
||||
t.test("char_class_lowercase_success", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("a");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_lowercase_success", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes - Uppercase Fail
|
||||
t.test("char_class_uppercase_fail", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("A");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_uppercase_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Char Classes with Dash - Lowercase Success
|
||||
t.test("char_class_with_dash_lowercase", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("f");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_with_dash_lowercase", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes with Dash - Literal Dash Success
|
||||
t.test("char_class_with_dash_literal_dash", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("-");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_with_dash_literal_dash", true, result.success());
|
||||
});
|
||||
|
||||
// Char Classes with Dash - Uppercase Fail
|
||||
t.test("char_class_with_dash_uppercase_fail", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
|
||||
|
||||
common_peg_parse_context ctx;
|
||||
common_peg_parse_result result;
|
||||
|
||||
ctx = common_peg_parse_context("A");
|
||||
result = parser.parse(ctx);
|
||||
t.assert_equal("char_class_with_dash_uppercase_fail", true, result.fail());
|
||||
});
|
||||
|
||||
// Sequences - Partial Match 1
|
||||
t.test("sequence_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("<thi", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Sequences - Partial Match 2
|
||||
t.test("sequence_partial_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("begin") + p.literal("end"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("begin", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Sequences - Partial Match 3
|
||||
t.test("sequence_partial_match_3", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("<think></", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_partial_match_3", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Sequences - Full Match
|
||||
t.test("sequence_full_match", [&](testing & t) {
|
||||
auto common_chat_combinator_parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello") + p.literal("world"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("helloworld", false);
|
||||
auto result = common_chat_combinator_parser.parse(ctx);
|
||||
t.assert_equal("sequence_full_match", true, result.success());
|
||||
});
|
||||
|
||||
// Sequences - No Match
|
||||
t.test("sequence_no_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("<think>I am common_chat_combinator_parser", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("sequence_no_match", true, result.fail());
|
||||
});
|
||||
|
||||
// Choices - Partial Match 1
|
||||
t.test("choices_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("opt", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Choices - Partial Match 2
|
||||
t.test("choices_partial_match_2", [&](testing & t) {
|
||||
auto parser =
|
||||
build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("choice", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Choices - Full Match 1
|
||||
t.test("choices_full_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("first", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_full_match_1", true, result.success());
|
||||
});
|
||||
|
||||
// Choices - Full Match 2
|
||||
t.test("choices_full_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("beta", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_full_match_2", true, result.success());
|
||||
});
|
||||
|
||||
// Choices - No Match
|
||||
t.test("choices_no_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); });
|
||||
|
||||
auto ctx = common_peg_parse_context("best", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("choices_no_match", true, result.fail());
|
||||
});
|
||||
|
||||
// Zero or More - Partial Match 1
|
||||
t.test("zero_or_more_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("a", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("zero_or_more_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Zero or More - Partial Match 2
|
||||
t.test("zero_or_more_partial_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("xyx", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("zero_or_more_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Zero or More - Full Match
|
||||
t.test("zero_or_more_full_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("test", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("zero_or_more_full_match", true, result.success());
|
||||
});
|
||||
|
||||
// One or More - Partial Match 1
|
||||
t.test("one_or_more_partial_match_1", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("rep", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_partial_match_1", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// One or More - Partial Match 2
|
||||
t.test("one_or_more_partial_match_2", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("aba", true);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_partial_match_2", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// One or More - Full Match
|
||||
t.test("one_or_more_full_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("single", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_full_match", true, result.success());
|
||||
});
|
||||
|
||||
// One or More - No Match
|
||||
t.test("one_or_more_no_match", [&](testing & t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); });
|
||||
|
||||
auto ctx = common_peg_parse_context("success", false);
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_equal("one_or_more_no_match", true, result.fail());
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
t.test("recursive rules", [](testing &t) {
|
||||
// Test simple number
|
||||
t.test("simple_number", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("1", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test simple list
|
||||
t.test("simple_list", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[1]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test nested list
|
||||
t.test("nested_list", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[[2]]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test deeply nested list
|
||||
t.test("deeply_nested_list", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[[[3]]]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
});
|
||||
|
||||
// Test need_more_input match
|
||||
t.test("need_more_input_match", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[[", true);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test no match
|
||||
t.test("no_match", [](testing &t) {
|
||||
auto value_parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("number", p.chars("0-9"));
|
||||
p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
|
||||
return p.rule("value", p.ref("number") | p.ref("list"));
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx("[a]", false);
|
||||
auto result = value_parser.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_fail", true, result.fail());
|
||||
});
|
||||
});
|
||||
}
|
||||
250
tests/peg-parser/test-gbnf-generation.cpp
Normal file
250
tests/peg-parser/test-gbnf-generation.cpp
Normal file
@@ -0,0 +1,250 @@
|
||||
#include "tests.h"
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
#include <regex>
|
||||
|
||||
static std::string trim_leading_space(const std::string & s) {
|
||||
static const std::regex leading_ws_re = std::regex(R"((^|\n)\s+)");
|
||||
return std::regex_replace(s, leading_ws_re, "$1");
|
||||
}
|
||||
|
||||
// Compares two GBNF grammars after stripping leading whitespace from both via
// trim_leading_space, so expected grammars can be written as indented raw strings.
static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) {
    const std::string normalized_expected = trim_leading_space(expected);
    const std::string normalized_actual   = trim_leading_space(actual);

    t.assert_equal("gbnf are equal", normalized_expected, normalized_actual);
}
|
||||
|
||||
// Checks the GBNF text generated from PEG parsers. Each sub-test builds a
// parser, renders it through build_grammar, and compares the output with the
// expected grammar (leading whitespace is ignored by assert_gbnf_equal).
void test_gbnf_generation(testing &t) {
    t.test("literal grammar generation", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "hello"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("char class grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a-z]", 1, 1); });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= [a-z]
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("sequence grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            return p.literal("hello") + p.literal(" ") + p.literal("world");
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "hello" " " "world"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("choice grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            return p.literal("cat") | p.literal("dog");
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "cat" | "dog"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("one_or_more grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("a")); });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "a"+
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("zero_or_more grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("a")); });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "a"*
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("optional grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            return p.literal("hello") + p.optional(p.literal(" world"));
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "hello" " world"?
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("until grammar", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) { return p.until("</tag>"); });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        // One alternative per prefix of the delimiter, each followed by a
        // character that breaks the delimiter at that position.
        const std::string expected = R"""(
            root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])*
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("complex expressions with parentheses", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            return p.one_or_more(p.literal("a") | p.literal("b"));
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= ("a" | "b")+
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("rule references", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            auto digit = p.rule("digit", p.chars("[0-9]", 1, 1));
            return p.one_or_more(digit);
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            digit ::= [0-9]
            root ::= digit+
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("escaping in literals", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello\nworld\n!"); });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        // The real newlines in the literal are expected as "\n" escapes in the grammar.
        const std::string expected = R"""(
            root ::= "hello\nworld\n!"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("operator<< (whitespace insertion)", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            return p.literal("hello") << p.literal("world");
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= "hello" space "world"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("emit only reachable rules", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            // "orphan" is registered but never referenced from the root.
            p.rule("orphan", p.literal("orphan"));
            return p.literal("hello") + p.rule("child", p.literal(" world"));
        });

        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            child ::= " world"
            root ::= "hello" child
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);
    });

    t.test("emit only trigger rules (and references)", [](testing &t) {
        auto peg = build_peg_parser([](common_peg_parser_builder & p) {
            // rule-2 and rule-4 pass `true` as the third argument to p.rule.
            auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2"));
            p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true);
            p.rule("rule-3", p.literal("c") + p.ref("rule-4"));
            p.rule("rule-4", p.literal("d"), true);
            return rule1;
        });

        // Default generation: the full rule chain starting at rule-1.
        auto generated = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb); });

        const std::string expected = R"""(
            root ::= rule-1
            rule-1 ::= "a" rule-2
            rule-2 ::= "b" rule-3
            rule-3 ::= "c" rule-4
            rule-4 ::= "d"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected, generated);

        // Generation with the second build_grammar argument set: the root becomes
        // the flagged rules, and only rules reachable from them are emitted.
        auto generated_lazy = build_grammar([&](const common_grammar_builder & gb) { peg.build_grammar(gb, true); });

        const std::string expected_lazy = R"""(
            root ::= rule-2 | rule-4
            rule-2 ::= "b" rule-3
            rule-3 ::= "c" rule-4
            rule-4 ::= "d"
            space ::= | " " | "\n"{1,2} [ \t]{0,20}
        )""";
        assert_gbnf_equal(t, expected_lazy, generated_lazy);
    });
}
|
||||
109
tests/peg-parser/test-json-parser.cpp
Normal file
109
tests/peg-parser/test-json-parser.cpp
Normal file
@@ -0,0 +1,109 @@
|
||||
#include "tests.h"
|
||||
|
||||
// Exercises the built-in JSON parser (p.json()) on complete and truncated
// documents, plus p.json_member for a single `"key": value` pair.
void test_json_parser(testing &t) {
    // Every JSON sub-test builds its own parser through this factory.
    const auto make_json_parser = []() {
        return build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
    };

    // Flat object with string, number and boolean values
    t.test("simple JSON object parsing", [&](testing &t) {
        auto json = make_json_parser();

        const std::string input = R"({"name": "test", "value": 42, "flag": true})";
        common_peg_parse_context ctx(input);

        auto result = json.parse(ctx);

        t.assert_equal("result_is_success", true, result.success());
        t.assert_equal("result_end", input.size(), result.end);
    });

    // Array mixing numbers, strings, booleans and null
    t.test("JSON array with mixed types", [&](testing &t) {
        auto json = make_json_parser();

        const std::string input = R"([1, "hello", true, null, 3.14])";
        common_peg_parse_context ctx(input);

        auto result = json.parse(ctx);

        t.assert_equal("result_is_success", true, result.success());
        t.assert_equal("result_end", input.size(), result.end);
    });

    // Objects and arrays nested inside each other
    t.test("nested JSON with objects and arrays", [&](testing &t) {
        auto json = make_json_parser();

        const std::string input =
            R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})";
        common_peg_parse_context ctx(input);

        auto result = json.parse(ctx);

        t.assert_equal("result_is_success", true, result.success());
        t.assert_equal("result_end", input.size(), result.end);
    });

    // Object cut off right after a key's colon
    t.test("need_more_input() parsing - incomplete object", [&](testing &t) {
        auto json = make_json_parser();

        const std::string input = R"({"name": "test", "value": )";
        common_peg_parse_context ctx(input, true);

        auto result = json.parse(ctx);

        t.assert_equal("result_is_need_more_input", true, result.need_more_input());
    });

    // Array cut off after a separator
    t.test("need_more_input() parsing - incomplete array", [&](testing &t) {
        auto json = make_json_parser();

        const std::string input = R"([1, 2, 3, )";
        common_peg_parse_context ctx(input, true);

        auto result = json.parse(ctx);

        t.assert_equal("result_is_need_more_input", true, result.need_more_input());
    });

    // Nested object cut off inside the inner object
    t.test("need_more_input() parsing - incomplete nested structure", [&](testing &t) {
        auto json = make_json_parser();

        const std::string input = R"({"data": {"nested": )";
        common_peg_parse_context ctx(input, true);

        auto result = json.parse(ctx);

        t.assert_equal("result_is_need_more_input", true, result.need_more_input());
    });

    t.test("object member", [](testing &t) {
        // One parser shared by the three sub-tests below: the member "name"
        // whose value is a quoted run of lowercase letters.
        auto parser = build_peg_parser([](common_peg_parser_builder & p) {
            return p.json_member("name", "\"" + p.chars("[a-z]") + "\"");
        });

        t.test("success", [&](testing &t) {
            const std::string input = R"("name": "bob")";
            common_peg_parse_context ctx(input, false);

            auto result = parser.parse(ctx);
            t.assert_true("success", result.success());
        });

        t.test("partial", [&](testing &t) {
            // The closing quote of the value is missing.
            const std::string input = R"("name": "bo)";
            common_peg_parse_context ctx(input, true);

            auto result = parser.parse(ctx);
            t.assert_true("need more input", result.need_more_input());
        });

        t.test("failed", [&](testing &t) {
            // An array is not a `"name": ...` member at all.
            const std::string input = R"([])";
            common_peg_parse_context ctx(input, false);

            auto result = parser.parse(ctx);
            t.assert_true("fail", result.fail());
        });
    });
}
|
||||
28
tests/peg-parser/test-json-serialization.cpp
Normal file
28
tests/peg-parser/test-json-serialization.cpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#include "tests.h"
|
||||
|
||||
// Round-trip check for parser serialization: a parser is dumped to JSON,
// rebuilt from that JSON, and both instances must produce the same result on
// the same input. Also benchmarks the deserialization step.
void test_json_serialization(testing &t) {
    auto original = build_peg_parser([](common_peg_parser_builder & p) {
        return "<tool_call>" + p.json() + "</tool_call>";
    });

    auto json_serialized = original.to_json().dump();

    t.test("compare before/after", [&](testing &t) {
        auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));

        // The grammar is "<tool_call>" + JSON + "</tool_call>", so the payload
        // has to be wrapped in those tags. With a bare JSON document both
        // parsers fail on the first literal, and the equality checks below
        // pass without ever exercising the serialized JSON rules.
        std::string input =
            R"(<tool_call>{"name": "test", "values": [1, 2, 3], "nested": {"a": true}}</tool_call>)";
        common_peg_parse_context ctx1(input);
        common_peg_parse_context ctx2(input);

        auto result1 = original.parse(ctx1);
        auto result2 = deserialized.parse(ctx2);

        // Guard against a vacuous comparison: the reference parse must succeed.
        t.assert_true("original_succeeds", result1.success());
        t.assert_equal("both_succeed", result1.success(), result2.success());
        t.assert_equal("same_end_pos", result1.end, result2.end);
    });

    // Measures parsing the JSON dump and rebuilding the parser from it
    // (100 is passed through to t.bench unchanged).
    t.bench("deserialize", [&]() {
        auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));
    }, 100);
}
|
||||
449
tests/peg-parser/test-unicode.cpp
Normal file
449
tests/peg-parser/test-unicode.cpp
Normal file
@@ -0,0 +1,449 @@
|
||||
#include "tests.h"
|
||||
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <cctype>
|
||||
|
||||
// Compares two parse result types by their printable names, so that a
// mismatch is reported with readable text in the assertion output.
static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) {
    const auto expected_name = common_peg_parse_result_type_name(expected);
    const auto actual_name   = common_peg_parse_result_type_name(actual);
    t.assert_equal(expected_name, actual_name);
}
|
||||
|
||||
// Renders a byte string for display: bytes that std::isprint accepts are
// copied as-is, every other byte becomes a "\x" escape followed by two
// lowercase hex digits.
static std::string hex_dump(const std::string& str) {
    static const char digits[] = "0123456789abcdef";

    std::string rendered;
    rendered.reserve(str.size());

    for (const unsigned char byte : str) {
        if (std::isprint(byte) != 0) {
            rendered.push_back(static_cast<char>(byte));
            continue;
        }
        // Non-printable byte: emit high nibble then low nibble.
        rendered += "\\x";
        rendered.push_back(digits[byte >> 4]);
        rendered.push_back(digits[byte & 0x0F]);
    }

    return rendered;
}
|
||||
|
||||
// UTF-8 handling tests for the PEG parser primitives: any(), chars() with
// Unicode ranges, until() and json_string_content().
//
// Every group is table driven. The per-case loop that used to be repeated in
// each group lives in the local `run_cases` helper; the groups only differ in
// their grammar, their case table and their `run_mode`.
void test_unicode(testing &t) {
    // One expectation: parse `input`, expect `expected_result`, and (when the
    // group's mode asks for it) expect the matched span to be `expected_text`.
    struct test_case {
        std::string input;
        std::string expected_text;
        common_peg_parse_result_type expected_result;
    };

    // How a group drives its cases.
    struct run_mode {
        // Second constructor argument of common_peg_parse_context. It is true in
        // the groups that expect NEED_MORE_INPUT results, so presumably it marks
        // the input as partial -- confirm against peg-parser.h.
        bool   partial_input;
        bool   text_on_success;    // compare matched text when the parse succeeds
        bool   text_on_need_more;  // compare matched text on need_more_input
        size_t trailing;           // bytes dropped from the end of the match before comparing
    };

    const run_mode streaming         = {true,  true,  true,  0};
    const run_mode streaming_partial = {true,  false, true,  0};
    const run_mode complete          = {false, true,  false, 0};
    const run_mode complete_quoted   = {false, true,  false, 1}; // 1 = exclude closing quote
    const run_mode result_only       = {false, false, false, 0};

    // Runs every case as its own sub-test named "case <i>: <hex dump of input>".
    auto run_cases = [](testing & t, auto & parser, const std::vector<test_case> & cases, const run_mode & mode) {
        for (size_t i = 0; i < cases.size(); i++) {
            const auto & tc = cases[i];
            const std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);

            t.test(test_name, [&](testing & t) {
                common_peg_parse_context ctx(tc.input, mode.partial_input);
                auto result = parser.parse(ctx);

                // Assert result type matches
                assert_result_equal(t, tc.expected_result, result.type);

                // Assert matched text only for the outcomes this group checks
                const bool compare_text =
                    (mode.text_on_success   && result.success()) ||
                    (mode.text_on_need_more && result.need_more_input());
                if (compare_text) {
                    std::string matched = tc.input.substr(result.start, result.end - result.start - mode.trailing);
                    t.assert_equal(tc.expected_text, matched);
                }
            });
        }
    };

    t.test("any", [&](testing &t) {
        std::vector<test_case> test_cases {
            // Valid UTF-8 sequences
            {"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS},
            {std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
            {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
            {std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},

            // Incomplete UTF-8 sequences (partial bytes at end)
            {std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
            {std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
            {std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},

            // Invalid/malformed UTF-8 sequences
            {std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
            {std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL},
            {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
        };

        auto parser = build_peg_parser([](common_peg_parser_builder& p) {
            return p.sequence({p.one_or_more(p.any()), p.end()});
        });

        run_cases(t, parser, test_cases, streaming);
    });

    t.test("char classes", [&](testing &t) {
        t.test("unicode range U+4E00-U+9FFF (CJK)", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Within range - CJK Unified Ideographs
                {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
                {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
                {std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D
                {std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF

                // Outside range - should fail
                {"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII
                {std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range)
                {std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range)

                // Incomplete sequences in range
                {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00
                {std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()});
            });

            run_cases(t, parser, test_cases, streaming);
        });

        t.test("unicode range U+1F600-U+1F64F (emoticons)", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Within range - Emoticons (all 4-byte UTF-8)
                {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
                {std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601
                {std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F

                // Outside range
                {std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range)
                {std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range)
                {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range)

                // Incomplete sequences
                {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji
                {std::string("\xF0\x9F"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()});
            });

            run_cases(t, parser, test_cases, streaming);
        });

        t.test("mixed unicode ranges", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Match CJK
                {std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
                {std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60

                // Match emoticons
                {std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600

                // Match ASCII digits
                {"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS},

                // Don't match outside any range
                {"a", "", COMMON_PEG_PARSE_RESULT_FAIL},
                {std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680

                // Incomplete
                {std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
                {std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()});
            });

            run_cases(t, parser, test_cases, streaming);
        });
    });

    t.test("until parser", [&](testing &t) {
        t.test("ASCII delimiter with Unicode content", [&](testing &t) {
            std::vector<test_case> test_cases {
                // CJK characters before delimiter
                {std::string("\xE4\xBD\xA0\xE5\xA5\xBD</tag>"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},

                // Emoji before delimiter
                {std::string("\xF0\x9F\x98\x80</tag>"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},

                // Mixed content
                {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!</tag>"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.until("</tag>");
            });

            run_cases(t, parser, test_cases, complete);
        });

        t.test("incomplete UTF-8 at end", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Incomplete emoji at end, no delimiter
                {std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},

                // Incomplete CJK at end, no delimiter
                {std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},

                // Complete content, no delimiter (should consume all valid UTF-8)
                {std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.until("</tag>");
            });

            run_cases(t, parser, test_cases, streaming);
        });

        t.test("malformed UTF-8", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Invalid UTF-8 bytes
                {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},

                // Continuation byte without lead byte
                {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},

                // Invalid continuation byte
                {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.until("</tag>");
            });

            run_cases(t, parser, test_cases, result_only);
        });
    });

    t.test("json_string parser", [&](testing &t) {
        t.test("valid UTF-8 characters", [&](testing &t) {
            std::vector<test_case> test_cases {
                // ASCII only
                {"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS},

                // 2-byte UTF-8 (accented characters)
                {std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},

                // 3-byte UTF-8 (CJK)
                {std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},

                // 4-byte UTF-8 (emoji)
                {std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},

                // Mixed content
                {std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
            };

            // Built once for the group; the other groups already share one
            // parser across all of their cases.
            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.sequence({p.json_string_content(), p.literal("\"")});
            });

            run_cases(t, parser, test_cases, complete_quoted);
        });

        t.test("incomplete UTF-8", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Incomplete 2-byte sequence
                {std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},

                // Incomplete 3-byte sequence
                {std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},

                // Incomplete 4-byte sequence
                {std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},

                // Incomplete at very start
                {std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.json_string_content();
            });

            run_cases(t, parser, test_cases, streaming_partial);
        });

        t.test("malformed UTF-8", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Invalid UTF-8 bytes
                {std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},

                // Continuation byte without lead byte
                {std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},

                // Invalid continuation byte
                {std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},

                // Overlong encoding (security issue)
                {std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.json_string_content();
            });

            run_cases(t, parser, test_cases, result_only);
        });

        t.test("escape sequences with UTF-8", [&](testing &t) {
            std::vector<test_case> test_cases {
                // Unicode escape sequence
                {"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS},

                // Mix of UTF-8 and escape sequences
                {std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},

                // Escaped quote in UTF-8 string
                {std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
            };

            auto parser = build_peg_parser([](common_peg_parser_builder& p) {
                return p.sequence({p.json_string_content(), p.literal("\"")});
            });

            run_cases(t, parser, test_cases, complete_quoted);
        });
    });
}
|
||||
243
tests/peg-parser/testing.h
Normal file
243
tests/peg-parser/testing.h
Normal file
@@ -0,0 +1,243 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
|
||||
struct testing {
|
||||
std::ostream &out;
|
||||
std::vector<std::string> stack;
|
||||
std::regex filter;
|
||||
bool filter_tests = false;
|
||||
bool throw_exception = false;
|
||||
bool verbose = false;
|
||||
int tests = 0;
|
||||
int assertions = 0;
|
||||
int failures = 0;
|
||||
int unnamed = 0;
|
||||
int exceptions = 0;
|
||||
|
||||
static constexpr std::size_t status_column = 80;
|
||||
|
||||
explicit testing(std::ostream &os = std::cout) : out(os) {}
|
||||
|
||||
std::string indent() const {
|
||||
if (stack.empty()) {
|
||||
return "";
|
||||
}
|
||||
return std::string((stack.size() - 1) * 2, ' ');
|
||||
}
|
||||
|
||||
std::string full_name() const {
|
||||
return string_join(stack, ".");
|
||||
}
|
||||
|
||||
void log(const std::string & msg) {
|
||||
if (verbose) {
|
||||
out << indent() << " " << msg << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void set_filter(const std::string & re) {
|
||||
filter = std::regex(re);
|
||||
filter_tests = true;
|
||||
}
|
||||
|
||||
bool should_run() const {
|
||||
if (filter_tests) {
|
||||
if (!std::regex_match(full_name(), filter)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void run_with_exceptions(F &&f, const char *ctx) {
|
||||
try {
|
||||
f();
|
||||
} catch (const std::exception &e) {
|
||||
++failures;
|
||||
++exceptions;
|
||||
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
|
||||
if (throw_exception) {
|
||||
throw;
|
||||
}
|
||||
} catch (...) {
|
||||
++failures;
|
||||
++exceptions;
|
||||
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
|
||||
if (throw_exception) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
|
||||
std::string line = indent() + label;
|
||||
|
||||
std::string details;
|
||||
if (new_assertions > 0) {
|
||||
if (new_failures == 0) {
|
||||
details = std::to_string(new_assertions) + " assertion(s)";
|
||||
} else {
|
||||
details = std::to_string(new_failures) + " of " +
|
||||
std::to_string(new_assertions) + " assertion(s) failed";
|
||||
}
|
||||
}
|
||||
if (!extra.empty()) {
|
||||
if (!details.empty()) {
|
||||
details += ", ";
|
||||
}
|
||||
details += extra;
|
||||
}
|
||||
|
||||
if (!details.empty()) {
|
||||
line += " (" + details + ")";
|
||||
}
|
||||
|
||||
std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
|
||||
|
||||
if (line.size() + 1 < status_column) {
|
||||
line.append(status_column - line.size(), ' ');
|
||||
} else {
|
||||
line.push_back(' ');
|
||||
}
|
||||
|
||||
out << line << status << "\n";
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void test(const std::string &name, F f) {
|
||||
stack.push_back(name);
|
||||
if (!should_run()) {
|
||||
stack.pop_back();
|
||||
return;
|
||||
}
|
||||
|
||||
++tests;
|
||||
out << indent() << name << "\n";
|
||||
|
||||
int before_failures = failures;
|
||||
int before_assertions = assertions;
|
||||
|
||||
run_with_exceptions([&] { f(*this); }, "test");
|
||||
|
||||
int new_failures = failures - before_failures;
|
||||
int new_assertions = assertions - before_assertions;
|
||||
|
||||
print_result(name, new_failures, new_assertions);
|
||||
|
||||
stack.pop_back();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void test(F f) {
|
||||
test("test #" + std::to_string(++unnamed), f);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void bench(const std::string &name, F f, int iterations = 100) {
|
||||
stack.push_back(name);
|
||||
if (!should_run()) {
|
||||
stack.pop_back();
|
||||
return;
|
||||
}
|
||||
|
||||
++tests;
|
||||
out << indent() << "[bench] " << name << "\n";
|
||||
|
||||
int before_failures = failures;
|
||||
int before_assertions = assertions;
|
||||
|
||||
using clock = std::chrono::high_resolution_clock;
|
||||
|
||||
std::chrono::microseconds duration(0);
|
||||
|
||||
run_with_exceptions([&] {
|
||||
for (auto i = 0; i < iterations; i++) {
|
||||
auto start = clock::now();
|
||||
f();
|
||||
duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
|
||||
}
|
||||
}, "bench");
|
||||
|
||||
auto avg_elapsed = duration.count() / iterations;
|
||||
auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
|
||||
auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
|
||||
|
||||
int new_failures = failures - before_failures;
|
||||
int new_assertions = assertions - before_assertions;
|
||||
|
||||
std::string extra =
|
||||
"n=" + std::to_string(iterations) +
|
||||
" avg=" + std::to_string(avg_elapsed) + "us" +
|
||||
" rate=" + std::to_string(int(rate)) + "/s";
|
||||
|
||||
print_result("[bench] " + name, new_failures, new_assertions, extra);
|
||||
|
||||
stack.pop_back();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void bench(F f, int iterations = 100) {
|
||||
bench("bench #" + std::to_string(++unnamed), f, iterations);
|
||||
}
|
||||
|
||||
// Assertions
|
||||
bool assert_true(bool cond) {
|
||||
return assert_true("", cond);
|
||||
}
|
||||
|
||||
bool assert_true(const std::string &msg, bool cond) {
|
||||
++assertions;
|
||||
if (!cond) {
|
||||
++failures;
|
||||
out << indent() << "ASSERT TRUE FAILED";
|
||||
if (!msg.empty()) {
|
||||
out << " : " << msg;
|
||||
}
|
||||
out << "\n";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
bool assert_equal(const A &expected, const B &actual) {
|
||||
return assert_equal("", expected, actual);
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
|
||||
++assertions;
|
||||
if (!(actual == expected)) {
|
||||
++failures;
|
||||
out << indent() << "ASSERT EQUAL FAILED";
|
||||
if (!msg.empty()) {
|
||||
out << " : " << msg;
|
||||
}
|
||||
out << "\n";
|
||||
|
||||
out << indent() << " expected: " << expected << "\n";
|
||||
out << indent() << " actual : " << actual << "\n";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Print summary and return an exit code
|
||||
int summary() const {
|
||||
out << "\n";
|
||||
out << "tests : " << tests << "\n";
|
||||
out << "assertions : " << assertions << "\n";
|
||||
out << "failures : " << failures << "\n";
|
||||
out << "exceptions : " << exceptions << "\n";
|
||||
return failures == 0 ? 0 : 1;
|
||||
}
|
||||
};
|
||||
24
tests/peg-parser/tests.h
Normal file
24
tests/peg-parser/tests.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
// Common includes for all test files
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "testing.h"
|
||||
#include "peg-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "simple-tokenize.h"
|
||||
|
||||
struct bench_tool_call {
|
||||
std::string id;
|
||||
std::string name;
|
||||
nlohmann::ordered_json args;
|
||||
};
|
||||
|
||||
// Test function declarations
|
||||
void test_basic(testing &t);
|
||||
void test_json_parser(testing &t);
|
||||
void test_gbnf_generation(testing &t);
|
||||
void test_unicode(testing &t);
|
||||
void test_json_serialization(testing &t);
|
||||
768
tests/test-chat-peg-parser.cpp
Normal file
768
tests/test-chat-peg-parser.cpp
Normal file
@@ -0,0 +1,768 @@
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "chat-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "peg-parser.h"
|
||||
#include "peg-parser/testing.h"
|
||||
#include "peg-parser/simple-tokenize.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static json create_tools();
|
||||
static void test_example_native(testing & t);
|
||||
static void test_example_qwen3_coder(testing & t);
|
||||
static void test_command7_parser_compare(testing & t);
|
||||
|
||||
// Test driver.
//
// argv[1] (optional): regular expression used to filter tests by name.
// LLAMA_TEST_VERBOSE=1 in the environment enables verbose logging.
// Returns the harness summary code, or 1 when the filter is not a valid regex.
int main(int argc, char *argv[]) {
    testing t(std::cout);
    if (argc >= 2) {
        // The filter comes straight from the command line; a malformed
        // pattern makes std::regex throw, which would otherwise abort the
        // process with an unhandled exception.
        try {
            t.set_filter(argv[1]);
        } catch (const std::regex_error & e) {
            std::cerr << "invalid test filter '" << argv[1] << "': " << e.what() << "\n";
            return 1;
        }
    }

    const char * verbose = getenv("LLAMA_TEST_VERBOSE");
    if (verbose) {
        t.verbose = std::string(verbose) == "1";
    }

    t.test("native", test_example_native);
    t.test("qwen3 coder", test_example_qwen3_coder);
    t.test("comparison", test_command7_parser_compare);

    return t.summary();
}
|
||||
|
||||
// Builds the JSON array of tool definitions used by the tests below:
// get_current_weather, get_forecast and search_knowledge_base, in that order.
static json create_tools() {
    // Property schemas shared by the two weather tools.
    const json location_prop = {
        {"type", "string"},
        {"description", "The city and state, e.g. San Francisco, CA"}
    };
    const json unit_prop = {
        {"type", "string"},
        {"enum", {"celsius", "fahrenheit"}},
        {"description", "The temperature unit to use. Infer this from the users location."}
    };

    // Wraps a parameter schema into a {"type": "function", "function": {...}} entry.
    const auto make_tool = [](const std::string & name, const std::string & description, const json & parameters) {
        json tool = {
            {"type", "function"},
            {"function", {
                {"name", name},
                {"description", description},
                {"parameters", parameters},
            }}
        };
        return tool;
    };

    const json weather_params = {
        {"type", "object"},
        {"properties", {
            {"location", location_prop},
            {"unit", unit_prop}
        }},
        {"required", {"location", "unit"}},
    };

    const json forecast_params = {
        {"type", "object"},
        {"properties", {
            {"location", location_prop},
            {"unit", unit_prop},
            {"days", {
                {"type", "integer"},
                {"description", "Number of days to forecast (1-10)"},
                {"minimum", 1},
                {"maximum", 10}
            }}
        }},
        {"required", {"location", "unit"}},
    };

    const json search_params = {
        {"type", "object"},
        {"properties", {
            {"query", {
                {"type", "string"},
                {"description", "The search query string."}
            }},
            {"max_results", {
                {"type", "integer"},
                {"description", "The maximum number of results to return."},
                {"default", 5}
            }},
            {"category", {
                {"type", "string"},
                {"enum", {"api", "troubleshooting", "billing", "general"}},
                {"description", "Filter search by specific category."}
            }}
        }},
        {"required", {"query", "category"}},
        {"additionalProperties", false}
    };

    json tool_weather = make_tool(
        "get_current_weather",
        "Get the current weather in a given location",
        weather_params);

    json tool_forecast = make_tool(
        "get_forecast",
        "Get the weather forecast for a given location",
        forecast_params);

    json tool_search = make_tool(
        "search_knowledge_base",
        "Search the internal technical documentation knowledge base.",
        search_params);
    // Only the search tool is strict; the key is appended after "parameters".
    tool_search["function"]["strict"] = true;

    json tools = json::array();
    tools.push_back(tool_weather);
    tools.push_back(tool_forecast);
    tools.push_back(tool_search);

    return tools;
}
|
||||
|
||||
struct tool_argument {
|
||||
std::string name;
|
||||
std::string type;
|
||||
bool is_required;
|
||||
json schema;
|
||||
};
|
||||
|
||||
struct tool_definition {
|
||||
std::string name;
|
||||
std::vector<tool_argument> arguments;
|
||||
json schema;
|
||||
};
|
||||
|
||||
// Test fictitious model output that emits arguments as JSON.
|
||||
static void test_example_native(testing & t) {
|
||||
struct test_case {
|
||||
// Parameters
|
||||
std::string name;
|
||||
json tools;
|
||||
common_chat_tool_choice tool_choice;
|
||||
common_reasoning_format reasoning_format;
|
||||
json json_schema;
|
||||
bool parallel_tool_calls;
|
||||
bool thinking_forced_open;
|
||||
std::string input;
|
||||
|
||||
// Expect
|
||||
std::string expect_reasoning;
|
||||
std::string expect_content;
|
||||
std::vector<common_chat_tool_call> expect_tool_calls;
|
||||
};
|
||||
|
||||
auto build_parser = [](const test_case & tc) {
|
||||
return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
|
||||
auto reasoning = p.eps();
|
||||
if (tc.thinking_forced_open) {
|
||||
// If thinking is forced open, expect a closing tag
|
||||
reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
|
||||
} else {
|
||||
// Otherwise, optionally accept thinking wrapped in tags
|
||||
reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
|
||||
}
|
||||
|
||||
// tool calling parser
|
||||
if (tc.tools.is_array() && !tc.tools.empty()) {
|
||||
auto tools = p.choice();
|
||||
for (const auto & tool : tc.tools) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\"");
|
||||
auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
|
||||
|
||||
tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}");
|
||||
};
|
||||
|
||||
auto parallel_calls = p.eps();
|
||||
if (tc.parallel_tool_calls) {
|
||||
parallel_calls = p.zero_or_more("," << tools);
|
||||
}
|
||||
|
||||
auto tool_call = p.trigger_rule("tool-call",
|
||||
p.sequence({
|
||||
p.literal("<tool_call>["),
|
||||
tools,
|
||||
parallel_calls,
|
||||
p.literal("]</tool_call>")
|
||||
})
|
||||
);
|
||||
|
||||
return p.sequence({
|
||||
(reasoning_in_content ? p.eps() : reasoning),
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.optional(p.space() + tool_call),
|
||||
p.space(),
|
||||
p.end()
|
||||
});
|
||||
}
|
||||
|
||||
// response_format parser
|
||||
if (tc.json_schema.is_object() && !tc.json_schema.empty()) {
|
||||
return p.sequence({
|
||||
(reasoning_in_content ? p.eps() : reasoning),
|
||||
p.content(p.schema(p.json(), "response-output", tc.json_schema)),
|
||||
p.space(),
|
||||
p.end()
|
||||
});
|
||||
}
|
||||
|
||||
// Content-only parser
|
||||
return p.sequence({
|
||||
(reasoning_in_content ? p.eps() : reasoning),
|
||||
p.content(p.rest()),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
std::vector<test_case> test_cases = std::vector<test_case>{
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = false",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .input = */ (
|
||||
"<think>The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "The user said hello, I must say hello back",
|
||||
/* .expect_content = */ "Hello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = false and no reasoning",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .input = */ (
|
||||
"Hello"
|
||||
),
|
||||
/* .expect_reasoning = */ "",
|
||||
/* .expect_content = */ "Hello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = false and reasoning_format = none",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"<think>The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "",
|
||||
/* .expect_content = */ "<think>The user said hello, I must say hello back</think>\nHello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = true",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "The user said hello, I must say hello back",
|
||||
/* .expect_content = */ "Hello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "content with thinking_forced_open = true and reasoning_format = none",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"The user said hello, I must say hello back</think>\nHello"
|
||||
),
|
||||
/* .expect_reasoning = */ "",
|
||||
/* .expect_content = */ "The user said hello, I must say hello back</think>\nHello",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
{
|
||||
/* .name = */ "tools with tool_choice = auto and no parallel_tool_calls",
|
||||
/* .tools = */ create_tools(),
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"I must get the weather in New York</think>\n"
|
||||
"<tool_call>["
|
||||
R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
|
||||
"]</tool_call>"
|
||||
),
|
||||
/* .expect_reasoning = */ "I must get the weather in New York",
|
||||
/* .expect_content = */ "",
|
||||
/* .expect_tool_calls = */ {{
|
||||
/* .name = */ "get_current_weather",
|
||||
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
|
||||
/* .id = */ "",
|
||||
}},
|
||||
},
|
||||
{
|
||||
/* .name = */ "tools with tool_choice = auto and parallel_tool_calls",
|
||||
/* .tools = */ create_tools(),
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {},
|
||||
/* .parallel_tool_calls = */ true,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you."
|
||||
"<tool_call>["
|
||||
R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
|
||||
", "
|
||||
R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})"
|
||||
", "
|
||||
R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})"
|
||||
", "
|
||||
R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})"
|
||||
"]</tool_call>"
|
||||
),
|
||||
/* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
|
||||
/* .expect_content = */ "Let me search that for you.",
|
||||
/* .expect_tool_calls = */ {{
|
||||
/* .name = */ "get_current_weather",
|
||||
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
|
||||
/* .id = */ "",
|
||||
}, {
|
||||
/* .name = */ "get_current_weather",
|
||||
/* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})",
|
||||
/* .id = */ "",
|
||||
}, {
|
||||
/* .name = */ "get_forecast",
|
||||
/* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})",
|
||||
/* .id = */ "",
|
||||
}, {
|
||||
/* .name = */ "get_forecast",
|
||||
/* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})",
|
||||
/* .id = */ "",
|
||||
}},
|
||||
},
|
||||
{
|
||||
/* .name = */ "response_format with thinking_forced_open = true",
|
||||
/* .tools = */ {},
|
||||
/* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .json_schema = */ {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"invoice_number", {{"type", "string"}}},
|
||||
{"amount", {{"type", "number"}}},
|
||||
{"due_date", {{"type", "string"}}}
|
||||
}},
|
||||
{"required", {"invoice_number", "amount", "due_date"}}
|
||||
},
|
||||
/* .parallel_tool_calls = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .input = */ (
|
||||
"I must produce the invoice in the requested format</think>\n"
|
||||
R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"
|
||||
),
|
||||
/* .expect_reasoning = */ "I must produce the invoice in the requested format",
|
||||
/* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})",
|
||||
/* .expect_tool_calls = */ {},
|
||||
},
|
||||
};
|
||||
|
||||
for (const auto & tc : test_cases) {
|
||||
t.test(tc.name, [&](testing & t) {
|
||||
auto parser = build_parser(tc);
|
||||
auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
auto grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
for (auto const & def : tc.tools) {
|
||||
auto function = def.at("function");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
};
|
||||
parser.build_grammar(builder, lazy);
|
||||
});
|
||||
|
||||
t.log("Grammar:");
|
||||
for (auto const & line : string_split(grammar, "\n")) {
|
||||
t.log(line);
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
t.assert_true("success", result.success());
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_native_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
t.assert_equal("content equal", tc.expect_content, msg.content);
|
||||
t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content);
|
||||
t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size());
|
||||
for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) {
|
||||
t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name);
|
||||
t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Exercises a Qwen3-Coder style tool call format:
//
//   <tool_call>
//   <function=NAME>
//   <parameter=ARG>VALUE</parameter>
//   </function>
//   </tool_call>
//
// A parser is generated from the tool definitions returned by create_tools(),
// its grammar is logged, and a sample output is then fed to the parser one
// token at a time to check that every prefix parses without failing and that
// diffs between consecutive partial messages can be computed.
static void test_example_qwen3_coder(testing & t) {
    auto tools = create_tools();
    auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
        auto content = p.rule("content", p.content(p.until("<tool_call>")));

        std::vector<common_peg_parser> tool_parsers;
        for (auto const & def : tools) {
            auto function = def.at("function");
            std::string name = function.at("name");
            auto parameters = function.at("parameters");
            auto properties = parameters.at("properties");

            // JSON Schema puts `required` on the object schema, i.e. on
            // `parameters`. The previous lookup location (`function`) is kept
            // as a fallback so tool definitions that relied on it still work.
            std::set<std::string> required_properties;
            if (parameters.contains("required")) {
                parameters.at("required").get_to(required_properties);
            } else if (function.contains("required")) {
                function.at("required").get_to(required_properties);
            }

            std::vector<common_peg_parser> arg_parsers;
            for (const auto & [param_name, param_schema] : properties.items()) {
                bool is_required = required_properties.find(param_name) != required_properties.end();
                auto type = param_schema.value("type", "object");

                auto arg = p.tool_arg(p.sequence({
                    p.tool_arg_open("<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">"),
                    // String values are taken raw up to the closing tag that is
                    // followed by either the next parameter or the end of the
                    // function; every other type is parsed as JSON.
                    (type == "string" ?
                        p.tool_arg_string_value(
                            p.schema(
                                p.until_one_of({
                                    "</parameter>\n<parameter=",
                                    "</parameter>\n</function>"
                                }),
                                "tool-" + name + "-arg-" + param_name + "-schema",
                                param_schema,
                                true
                            )
                        ) : p.tool_arg_json_value(
                            p.schema(
                                p.json(),
                                "tool-" + name + "-arg-" + param_name + "-schema",
                                param_schema
                            )
                        )
                    ),
                    p.tool_arg_close(
                        "</parameter>\n" +
                        p.peek(p.literal("<parameter=") | p.literal("</function>"))
                    )
                }));

                arg_parsers.push_back(is_required ?
                    p.rule("tool-" + name + "-arg-" + param_name, arg) :
                    p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg)));
            }

            tool_parsers.push_back(p.rule("tool-" + name,
                p.tool_open("<function=" + p.tool_name(p.literal(name)) + ">")
                << p.sequence(arg_parsers)
                << p.tool_close(p.literal("</function>"))
            ));
        }

        auto tool_call = p.trigger_rule("tool-call",
            "<tool_call>"
            << p.choice(tool_parsers)
            << "</tool_call>"
        );

        return content + p.zero_or_more(p.space() + tool_call) + p.end();
    });

    auto grammar = build_grammar([&](const common_grammar_builder & builder) {
        for (auto const & def : tools) {
            auto function = def.at("function");
            auto parameters = function.at("parameters");
            builder.resolve_refs(parameters);
        }
        parser.build_grammar(builder);
    });

    t.log("Grammar:");
    for (auto const & line : string_split(grammar, "\n")) {
        t.log(line);
    }

    t.test("incremental parsing", [&](testing &t) {
        std::string input =
            "Let me search the knowledge base for cat pictures."
            "<tool_call>\n"
            "<function=search_knowledge_base>\n"
            "<parameter=query>cat pictures</parameter>\n"
            "<parameter=category>general</parameter>\n"
            "</function>\n"
            "</tool_call>";

        std::vector<std::string> tokens = simple_tokenize(input);

        common_chat_msg prev;
        // Running prefix of the input; grown by one token per iteration
        // instead of re-concatenating all previous tokens every time.
        std::string in;
        for (auto it = tokens.begin(); it != tokens.end(); it++) {
            in += *it;

            // Every prefix except the last one is parsed as partial input.
            common_peg_parse_context ctx(in, it + 1 < tokens.end());

            auto result = parser.parse(ctx);
            if (!t.assert_equal("not fail", false, result.fail())) {
                t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
            }

            common_chat_msg msg;
            auto mapper = common_chat_peg_constructed_mapper(msg);
            mapper.from_ast(ctx.ast, result);

            t.log("===========================================");
            t.log("Iteration " + std::to_string(in.size()));
            t.log("Reasoning: " + msg.reasoning_content);
            t.log("Content : " + msg.content);
            for (const auto & tc : msg.tool_calls) {
                t.log("Tool name: " + tc.name);
                t.log("Tool args: " + tc.arguments);
            }

            try {
                // This shouldn't emit any runtime errors
                auto diffs = common_chat_msg_diff::compute_diffs(prev, msg);
            } catch(const std::exception & e) {
                t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
                t.assert_true(std::string("failed with ") + e.what(), false);
            }

            prev = msg;
        }
    });
}
|
||||
|
||||
// Parses the same Command-7 style output (thinking block followed by an action
// block holding a JSON array of tool calls) with the PEG based parser and with
// the legacy common_chat_msg_parser code path, once as plain tests and then as
// benchmarks, both on the complete input and incrementally token by token.
void test_command7_parser_compare(testing & t) {
    auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) {
        auto think_block = p.reasoning_block(
            "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>");

        auto response_block = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>";

        // The three members a tool call object may carry.
        auto field_id   = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\"")));
        auto field_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\"")));
        auto field_args = "\"parameters\"" << (":" << p.tool_args(p.json()));

        auto any_field = p.rule("tool-call-fields", field_id | field_name | field_args);
        auto one_call = p.rule("tool-call", p.tool(
            p.tool_open(p.literal("{"))
            << any_field
            << p.zero_or_more( p.literal(",") << any_field)
            << p.tool_close(p.literal("}"))
        ));

        auto action_block = p.rule("tool-calls",
            "<|START_ACTION|>"
            << ("[" << one_call << p.zero_or_more(p.literal(",") << one_call) << "]")
            << "<|END_ACTION|>");

        return p.optional(think_block) << (action_block | response_block) + p.end();
    });

    // Dumps a parsed message to stdout under the given header line.
    auto dump_msg = [](const char * header, const auto & parsed) {
        std::cout << header;
        std::cout << "=== Reasoning ===\n";
        std::cout << parsed.reasoning_content << "\n";
        std::cout << "\n\n=== Content ===\n";
        std::cout << parsed.content << "\n";
        std::cout << "\n\n=== Tool Calls ===\n";
        for (const auto & call : parsed.tool_calls) {
            std::cout << "id: " << call.id << "\n";
            std::cout << "name: " << call.name << "\n";
            std::cout << "args: " << call.arguments << "\n";
        }
    };

    // PEG based code path.
    auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) {
        common_peg_parse_context ctx(input, is_partial);
        auto parse_result = p.parse(ctx);

        common_chat_msg parsed;
        auto mapper = common_chat_peg_native_mapper(parsed);
        mapper.from_ast(ctx.ast, parse_result);

        if (print_results) {
            dump_msg("== Parsed (new) ==\n", parsed);
        }
    };

    // Legacy code path.
    auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
        // Original common_chat_combinator_parser taken from chat.cpp
        common_chat_msg_parser builder(
            input,
            /* .is_partial = */ need_more_input,
            {
                /* .format = */ COMMON_CHAT_FORMAT_GENERIC,
                /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
                /* .reasoning_in_content = */ false,
                /* .thinking_forced_open = */ false,
            }
        );

        builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");

        static const common_regex start_action_regex("<\\|START_ACTION\\|>");
        static const common_regex end_action_regex("<\\|END_ACTION\\|>");
        static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
        static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");

        if (builder.try_find_regex(start_action_regex)) {
            // If we didn't extract thoughts, prelude includes them.
            auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } });
            for (const auto & tool_call : tool_calls.value) {
                std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
                std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
                std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
                if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
                    throw common_chat_msg_partial_exception("incomplete tool call");
                }
            }
            if (tool_calls.is_partial) {
                throw common_chat_msg_partial_exception("incomplete tool call");
            }
            builder.consume_regex(end_action_regex);
        } else if (builder.try_find_regex(start_response_regex)) {
            if (!builder.try_find_regex(end_response_regex)) {
                builder.add_content(builder.consume_rest());
                throw common_chat_msg_partial_exception(end_response_regex.str());
            }
        } else {
            builder.add_content(builder.consume_rest());
        }

        if (print_results) {
            dump_msg("== Parsed (legacy) ==\n", builder.result());
        }
    };

    std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a "
        "budget of $4000 for a two-week stay, we need to:\n\n"
        "1. Identify key historical sites and modern attractions in Japan.\n"
        "2. Find affordable accommodation options that provide a balance between comfort and cost.\n"
        "3. Determine the best modes of transportation for getting around Japan.\n"
        "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without "
        "overspending.\n"
        "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees "
        "to attractions.";

    // (tool_call_id, tool_name, parameters)
    std::vector<std::tuple<std::string, std::string, nlohmann::json>> tool_calls = {{
        "call_0",
        "plan_trip",
        nlohmann::json::parse(R"({
            "destination": "Japan",
            "duration": 14,
            "budget": 4000,
            "interests": ["historical sites", "modern attractions"],
            "accommodation_preferences": "affordable",
            "transportation_preferences": "efficient",
            "meal_preferences": "local cuisine"
        })")
    }};

    std::vector<std::string> tokens;

    // Build tokens
    if (!reasoning.empty()) {
        auto reasoning_tokens = simple_tokenize(reasoning);
        tokens.emplace_back("<|START_THINKING|>");
        tokens.insert(tokens.end(), reasoning_tokens.begin(), reasoning_tokens.end());
        tokens.emplace_back("<|END_THINKING|>");
    }

    if (!tool_calls.empty()) {
        tokens.emplace_back("<|START_ACTION|>");

        auto calls_json = nlohmann::json::array();
        for (const auto & tc : tool_calls) {
            auto entry = nlohmann::json::object();
            entry["tool_call_id"] = std::get<0>(tc);
            entry["tool_name"] = std::get<1>(tc);
            entry["parameters"] = std::get<2>(tc);
            calls_json.push_back(entry);
        }

        auto action_tokens = simple_tokenize(calls_json.dump(-1, ' ', true));
        tokens.insert(tokens.end(), action_tokens.begin(), action_tokens.end());

        tokens.emplace_back("<|END_ACTION|>");
    }

    std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string());

    // Run tests
    t.test("legacy_parse", [&](testing & /* t */) {
        test_legacy(input, false, false);
    });

    t.test("current_parse", [&](testing & /* t */) {
        test_current(parser, input, false, false);
    });

    // Run benchmarks
    t.bench("legacy_parse_benchmark complete", [&]() {
        test_legacy(input, false, false);
    });

    t.bench("legacy_parse_benchmark incremental", [&]() {
        std::string prefix;
        for (size_t idx = 0; idx < tokens.size(); ++idx) {
            prefix += tokens[idx];

            const bool more_to_come = idx + 1 < tokens.size();
            try {
                test_legacy(prefix, more_to_come, false);
            } catch (common_chat_msg_partial_exception & /* e */) {
                // Do nothing, this is expected
            }
        }
    }, 20);

    t.bench("current_parse_benchmark complete", [&]() {
        test_current(parser, input, false, false);
    }, 100);

    t.bench("current_parse_benchmark incremental", [&]() {
        std::string prefix;
        for (size_t idx = 0; idx < tokens.size(); ++idx) {
            prefix += tokens[idx];

            const bool more_to_come = idx + 1 < tokens.size();
            test_current(parser, prefix, more_to_come, false);
        }
    }, 20);
}
|
||||
25
tests/test-peg-parser.cpp
Normal file
25
tests/test-peg-parser.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "peg-parser/tests.h"
|
||||
|
||||
// Entry point of the PEG parser test binary.
//
// Usage: an optional first command line argument is handed to the harness as a
// test filter. When the LLAMA_TEST_VERBOSE environment variable is set, the
// harness' verbose flag is set to true if its value is exactly "1" and to
// false otherwise; when it is unset the flag is left untouched.
// The process exit code is whatever testing::summary() returns.
int main(int argc, char *argv[]) {
    testing t(std::cout);

    const bool has_filter = argc > 1;
    if (has_filter) {
        t.set_filter(argv[1]);
    }

    if (const char * env_verbose = getenv("LLAMA_TEST_VERBOSE")) {
        const std::string value(env_verbose);
        t.verbose = (value == "1");
    }

    t.test("basic", test_basic);
    t.test("unicode", test_unicode);
    t.test("json", test_json_parser);
    t.test("gbnf", test_gbnf_generation);
    t.test("serialization", test_json_serialization);

    return t.summary();
}
|
||||
@@ -2,11 +2,6 @@ set(TARGET llama-server)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
if (MINGW)
|
||||
# fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006
|
||||
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
|
||||
endif()
|
||||
|
||||
if (NOT LLAMA_HTTPLIB)
|
||||
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
|
||||
endif()
|
||||
|
||||
Binary file not shown.
@@ -791,7 +791,7 @@ static void handle_media(
|
||||
SRV_INF("downloading image from '%s'\n", url.c_str());
|
||||
auto res = common_remote_get_content(url, params);
|
||||
if (200 <= res.first && res.first < 300) {
|
||||
SRV_INF("downloaded %ld bytes\n", res.second.size());
|
||||
SRV_INF("downloaded %zu bytes\n", res.second.size());
|
||||
raw_buffer data;
|
||||
data.insert(data.end(), res.second.begin(), res.second.end());
|
||||
out_files.push_back(data);
|
||||
@@ -1045,6 +1045,9 @@ json oaicompat_chat_params_parse(
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
if (!chat_params.parser.empty()) {
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
|
||||
@@ -101,8 +101,6 @@ struct server_slot {
|
||||
std::string generated_text;
|
||||
llama_tokens generated_tokens;
|
||||
|
||||
common_chat_msg chat_msg;
|
||||
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
bool has_next_token = true;
|
||||
@@ -153,9 +151,6 @@ struct server_slot {
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
|
||||
@@ -183,13 +178,10 @@ struct server_slot {
|
||||
stop = STOP_TYPE_NONE;
|
||||
stopping_word = "";
|
||||
n_sent_text = 0;
|
||||
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
generated_tokens.clear();
|
||||
generated_token_probs.clear();
|
||||
chat_msg = {};
|
||||
json_schema = json();
|
||||
generated_tool_call_ids.clear();
|
||||
|
||||
// clear speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
@@ -302,23 +294,6 @@ struct server_slot {
|
||||
return timings;
|
||||
}
|
||||
|
||||
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
auto previous_msg = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
auto new_msg = common_chat_parse(
|
||||
generated_text,
|
||||
/* is_partial= */ stop != STOP_TYPE_EOS,
|
||||
task->params.oaicompat_chat_syntax);
|
||||
if (!new_msg.empty()) {
|
||||
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
@@ -1284,8 +1259,6 @@ struct server_context_impl {
|
||||
} else {
|
||||
res->content = tkn.text_to_send;
|
||||
res->tokens = { tkn.tok };
|
||||
|
||||
slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
res->n_decoded = slot.n_decoded;
|
||||
@@ -1317,8 +1290,14 @@ struct server_context_impl {
|
||||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.task->index;
|
||||
res->content = slot.generated_text;
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
// in stream mode, content and tokens are already in last partial chunk
|
||||
if (slot.task->params.stream) {
|
||||
res->content = "";
|
||||
res->tokens = llama_tokens{};
|
||||
} else {
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
}
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = slot.task->tokens.detokenize(ctx, true);
|
||||
res->response_fields = std::move(slot.task->params.response_fields);
|
||||
@@ -1338,7 +1317,6 @@ struct server_context_impl {
|
||||
res->res_type = slot.task->params.res_type;
|
||||
res->oaicompat_model = slot.task->params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
|
||||
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
@@ -2596,6 +2574,9 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
try {
|
||||
std::vector<server_task> tasks;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
const auto & prompt = data.at("prompt");
|
||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
@@ -2611,6 +2592,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
tasks.reserve(inputs.size());
|
||||
states.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
@@ -2628,10 +2610,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
task.params.res_type = res_type;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat_model = ctx_server.model_name;
|
||||
states.push_back(task.params.oaicompat_chat_syntax);
|
||||
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
rd.set_states(std::move(states));
|
||||
rd.post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
@@ -2657,7 +2641,6 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
// if single request, return single object instead of array
|
||||
res->ok(arr.size() == 1 ? arr[0] : arr);
|
||||
}
|
||||
|
||||
} else {
|
||||
// in streaming mode, the first error must be treated as non-stream response
|
||||
// this is to match the OAI API behavior
|
||||
@@ -2676,76 +2659,92 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
|
||||
}
|
||||
|
||||
// next responses are streamed
|
||||
// to be sent immediately
|
||||
json first_result_json = first_result->to_json();
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
res->data = format_anthropic_sse(first_result->to_json());
|
||||
res->data = format_anthropic_sse(first_result_json);
|
||||
} else {
|
||||
res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
|
||||
res->data = format_oai_sse(first_result_json);
|
||||
}
|
||||
res->status = 200;
|
||||
res->content_type = "text/event-stream";
|
||||
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
if (!res_this->data.empty()) {
|
||||
// flush the first chunk
|
||||
output = std::move(res_this->data);
|
||||
res_this->data.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
server_response_reader & rd = res_this->rd;
|
||||
|
||||
// check if there is more data
|
||||
if (!rd.has_next()) {
|
||||
static auto format_error = [](task_response_type res_type, const json & res_json) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
// Anthropic doesn't send [DONE], message_stop was already sent
|
||||
output = "";
|
||||
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
|
||||
output = "data: [DONE]\n\n";
|
||||
} else {
|
||||
output = "";
|
||||
}
|
||||
SRV_DBG("%s", "all results received, terminating stream\n");
|
||||
return false; // no more data, terminate
|
||||
}
|
||||
|
||||
// receive subsequent results
|
||||
auto result = rd.next(should_stop);
|
||||
if (result == nullptr) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
// send the results
|
||||
json res_json = result->to_json();
|
||||
if (result->is_error()) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse({
|
||||
return format_anthropic_sse({
|
||||
{"event", "error"},
|
||||
{"data", res_json},
|
||||
});
|
||||
} else {
|
||||
output = format_oai_sse(json {{ "error", res_json }});
|
||||
return format_oai_sse(json {{ "error", res_json }});
|
||||
}
|
||||
SRV_DBG("%s", "error received during streaming, terminating stream\n");
|
||||
return false; // terminate on error
|
||||
} else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse(res_json);
|
||||
} else {
|
||||
output = format_oai_sse(res_json);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// has next data, continue
|
||||
return true;
|
||||
try {
|
||||
if (should_stop()) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
if (!res_this->data.empty()) {
|
||||
// flush the first chunk
|
||||
output = std::move(res_this->data);
|
||||
res_this->data.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
server_response_reader & rd = res_this->rd;
|
||||
|
||||
// check if there is more data
|
||||
if (!rd.has_next()) {
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
// Anthropic doesn't send [DONE], message_stop was already sent
|
||||
output = "";
|
||||
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
|
||||
output = "data: [DONE]\n\n";
|
||||
} else {
|
||||
output = "";
|
||||
}
|
||||
SRV_DBG("%s", "all results received, terminating stream\n");
|
||||
return false; // no more data, terminate
|
||||
}
|
||||
|
||||
// receive subsequent results
|
||||
auto result = rd.next(should_stop);
|
||||
if (result == nullptr) {
|
||||
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
|
||||
return false; // should_stop condition met
|
||||
}
|
||||
|
||||
// send the results
|
||||
if (result->is_error()) {
|
||||
json res_json = result->to_json();
|
||||
output = format_error(res_type, res_json);
|
||||
SRV_DBG("%s", "error received during streaming, terminating stream\n");
|
||||
return false; // terminate on error
|
||||
} else {
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
json res_json = result->to_json();
|
||||
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
|
||||
output = format_anthropic_sse(res_json);
|
||||
} else {
|
||||
output = format_oai_sse(res_json);
|
||||
}
|
||||
}
|
||||
|
||||
// has next data, continue
|
||||
return true;
|
||||
|
||||
} catch (const std::exception & e) {
|
||||
json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER);
|
||||
output = format_error(res_type, error_json);
|
||||
|
||||
// terminate on exception
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -900,6 +900,7 @@ static bool should_strip_proxy_header(const std::string & header_name) {
|
||||
// Headers that get duplicated when router forwards child responses
|
||||
if (header_name == "server" ||
|
||||
header_name == "transfer-encoding" ||
|
||||
header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
|
||||
header_name == "keep-alive") {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -271,6 +271,10 @@ void server_response::terminate() {
|
||||
// server_response_reader
|
||||
//
|
||||
|
||||
void server_response_reader::set_states(std::vector<task_result_state> && states) {
|
||||
this->states = std::move(states);
|
||||
}
|
||||
|
||||
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
|
||||
id_tasks = server_task::get_list_id(tasks);
|
||||
queue_results.add_waiting_tasks(tasks);
|
||||
@@ -298,6 +302,12 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()>
|
||||
SRV_DBG("%s", "received error result, stopping further processing\n");
|
||||
return result;
|
||||
}
|
||||
if (!states.empty()) {
|
||||
// update the generation state if needed
|
||||
size_t idx = result->get_index();
|
||||
GGML_ASSERT(idx < states.size());
|
||||
result->update(states[idx]);
|
||||
}
|
||||
if (result->is_stop()) {
|
||||
received_count++;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <mutex>
|
||||
#include <unordered_set>
|
||||
|
||||
// struct for managing server tasks
|
||||
// in most cases, use server_response_reader to post new tasks and retrieve results
|
||||
struct server_queue {
|
||||
private:
|
||||
int id = 0;
|
||||
@@ -67,6 +69,8 @@ private:
|
||||
void cleanup_pending_task(int id_target);
|
||||
};
|
||||
|
||||
// struct for managing server responses
|
||||
// in most cases, use server_response_reader to retrieve results
|
||||
struct server_response {
|
||||
private:
|
||||
bool running = true;
|
||||
@@ -120,6 +124,10 @@ struct server_response_reader {
|
||||
bool cancelled = false;
|
||||
int polling_interval_seconds;
|
||||
|
||||
// tracking generation state and partial tool calls
|
||||
// only used by streaming completions
|
||||
std::vector<task_result_state> states;
|
||||
|
||||
// should_stop function will be called each polling_interval_seconds
|
||||
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
|
||||
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
|
||||
@@ -127,6 +135,7 @@ struct server_response_reader {
|
||||
stop();
|
||||
}
|
||||
|
||||
void set_states(std::vector<task_result_state> && states);
|
||||
void post_tasks(std::vector<server_task> && tasks);
|
||||
bool has_next() const;
|
||||
|
||||
|
||||
@@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
||||
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
||||
if (data.contains("chat_parser")) {
|
||||
params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
@@ -562,6 +565,7 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
|
||||
// server_task_result_cmpl_final
|
||||
//
|
||||
json server_task_result_cmpl_final::to_json() {
|
||||
GGML_ASSERT(is_updated && "update() must be called before to_json()");
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
@@ -579,8 +583,8 @@ json server_task_result_cmpl_final::to_json() {
|
||||
json server_task_result_cmpl_final::to_json_non_oaicompat() {
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"tokens", stream ? llama_tokens {} : tokens},
|
||||
{"content", content},
|
||||
{"tokens", tokens},
|
||||
{"id_slot", id_slot},
|
||||
{"stop", true},
|
||||
{"model", oaicompat_model},
|
||||
@@ -616,7 +620,7 @@ json server_task_result_cmpl_final::to_json_oaicompat() {
|
||||
json res = json {
|
||||
{"choices", json::array({
|
||||
json{
|
||||
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"text", content},
|
||||
{"index", index},
|
||||
{"logprobs", logprobs},
|
||||
{"finish_reason", finish_reason},
|
||||
@@ -697,6 +701,25 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat() {
|
||||
return res;
|
||||
}
|
||||
|
||||
common_chat_msg task_result_state::update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs) {
|
||||
generated_text += text_added;
|
||||
auto msg_prv_copy = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
auto new_msg = common_chat_parse(
|
||||
generated_text,
|
||||
is_partial,
|
||||
oaicompat_chat_syntax);
|
||||
if (!new_msg.empty()) {
|
||||
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg);
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
|
||||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
@@ -953,6 +976,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||
// server_task_result_cmpl_partial
|
||||
//
|
||||
json server_task_result_cmpl_partial::to_json() {
|
||||
GGML_ASSERT(is_updated && "update() must be called before to_json()");
|
||||
switch (res_type) {
|
||||
case TASK_RESPONSE_TYPE_NONE:
|
||||
return to_json_non_oaicompat();
|
||||
|
||||
@@ -161,6 +161,25 @@ struct result_prompt_progress {
|
||||
json to_json() const;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
struct task_result_state {
|
||||
// tracking diffs for partial tool calls
|
||||
std::vector<common_chat_msg_diff> diffs;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
common_chat_msg chat_msg;
|
||||
std::string generated_text; // append new chunks of generated text here
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
|
||||
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
|
||||
|
||||
// parse partial tool calls and update the internal state
|
||||
common_chat_msg update_chat_msg(
|
||||
const std::string & text_added,
|
||||
bool is_partial,
|
||||
std::vector<common_chat_msg_diff> & diffs);
|
||||
};
|
||||
|
||||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_slot = -1;
|
||||
@@ -175,6 +194,9 @@ struct server_task_result {
|
||||
virtual int get_index() {
|
||||
return -1;
|
||||
}
|
||||
virtual void update(task_result_state &) {
|
||||
// only used by server_task_result_cmpl_*
|
||||
}
|
||||
virtual json to_json() = 0;
|
||||
virtual ~server_task_result() = default;
|
||||
};
|
||||
@@ -233,9 +255,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
common_chat_msg oaicompat_msg; // to be populated by update()
|
||||
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -247,6 +270,11 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
@@ -280,7 +308,8 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
|
||||
bool is_updated = false;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -292,6 +321,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
||||
virtual json to_json() override;
|
||||
|
||||
virtual void update(task_result_state & state) override {
|
||||
is_updated = true;
|
||||
state.update_chat_msg(content, true, oaicompat_msg_diffs);
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat();
|
||||
|
||||
json to_json_oaicompat();
|
||||
|
||||
@@ -65,6 +65,7 @@ def test_server_slots():
|
||||
|
||||
def test_load_split_model():
|
||||
global server
|
||||
server.offline = False
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
|
||||
server.model_alias = "tinyllama-split"
|
||||
|
||||
@@ -17,7 +17,6 @@ def create_server():
|
||||
]
|
||||
)
|
||||
def test_router_chat_completion_stream(model: str, success: bool):
|
||||
# TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server
|
||||
global server
|
||||
server.start()
|
||||
content = ""
|
||||
@@ -48,3 +47,148 @@ def test_router_chat_completion_stream(model: str, success: bool):
|
||||
else:
|
||||
assert ex is not None
|
||||
assert content == ""
|
||||
|
||||
|
||||
def _get_model_status(model_id: str) -> str:
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
for item in res.body.get("data", []):
|
||||
if item.get("id") == model_id or item.get("model") == model_id:
|
||||
return item["status"]["value"]
|
||||
raise AssertionError(f"Model {model_id} not found in /models response")
|
||||
|
||||
|
||||
def _wait_for_model_status(model_id: str, desired: set[str], timeout: int = 60) -> str:
|
||||
deadline = time.time() + timeout
|
||||
last_status = None
|
||||
while time.time() < deadline:
|
||||
last_status = _get_model_status(model_id)
|
||||
if last_status in desired:
|
||||
return last_status
|
||||
time.sleep(1)
|
||||
raise AssertionError(
|
||||
f"Timed out waiting for {model_id} to reach {desired}, last status: {last_status}"
|
||||
)
|
||||
|
||||
|
||||
def _load_model_and_wait(
|
||||
model_id: str, timeout: int = 60, headers: dict | None = None
|
||||
) -> None:
|
||||
load_res = server.make_request(
|
||||
"POST", "/models/load", data={"model": model_id}, headers=headers
|
||||
)
|
||||
assert load_res.status_code == 200
|
||||
assert isinstance(load_res.body, dict)
|
||||
assert load_res.body.get("success") is True
|
||||
_wait_for_model_status(model_id, {"loaded"}, timeout=timeout)
|
||||
|
||||
|
||||
def test_router_unload_model():
|
||||
global server
|
||||
server.start()
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
unload_res = server.make_request("POST", "/models/unload", data={"model": model_id})
|
||||
assert unload_res.status_code == 200
|
||||
assert unload_res.body.get("success") is True
|
||||
_wait_for_model_status(model_id, {"unloaded"})
|
||||
|
||||
|
||||
def test_router_models_max_evicts_lru():
|
||||
global server
|
||||
server.models_max = 2
|
||||
server.start()
|
||||
|
||||
candidate_models = [
|
||||
"ggml-org/tinygemma3-GGUF:Q8_0",
|
||||
"ggml-org/test-model-stories260K",
|
||||
"ggml-org/test-model-stories260K-infill",
|
||||
]
|
||||
|
||||
# Load only the first 2 models to fill the cache
|
||||
first, second, third = candidate_models[:3]
|
||||
|
||||
_load_model_and_wait(first, timeout=120)
|
||||
_load_model_and_wait(second, timeout=120)
|
||||
|
||||
# Verify both models are loaded
|
||||
assert _get_model_status(first) == "loaded"
|
||||
assert _get_model_status(second) == "loaded"
|
||||
|
||||
# Load the third model - this should trigger LRU eviction of the first model
|
||||
_load_model_and_wait(third, timeout=120)
|
||||
|
||||
# Verify eviction: third is loaded, first was evicted
|
||||
assert _get_model_status(third) == "loaded"
|
||||
assert _get_model_status(first) == "unloaded"
|
||||
|
||||
|
||||
def test_router_no_models_autoload():
|
||||
global server
|
||||
server.no_models_autoload = True
|
||||
server.start()
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 400
|
||||
assert "error" in res.body
|
||||
|
||||
_load_model_and_wait(model_id)
|
||||
|
||||
success_res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert success_res.status_code == 200
|
||||
assert "error" not in success_res.body
|
||||
|
||||
|
||||
def test_router_api_key_required():
|
||||
global server
|
||||
server.api_key = "sk-router-secret"
|
||||
server.start()
|
||||
|
||||
model_id = "ggml-org/tinygemma3-GGUF:Q8_0"
|
||||
auth_headers = {"Authorization": f"Bearer {server.api_key}"}
|
||||
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 401
|
||||
assert res.body.get("error", {}).get("type") == "authentication_error"
|
||||
|
||||
_load_model_and_wait(model_id, headers=auth_headers)
|
||||
|
||||
authed = server.make_request(
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers=auth_headers,
|
||||
data={
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"max_tokens": 4,
|
||||
},
|
||||
)
|
||||
assert authed.status_code == 200
|
||||
assert "error" not in authed.body
|
||||
|
||||
@@ -7,6 +7,7 @@ import subprocess
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
import sys
|
||||
import requests
|
||||
import time
|
||||
@@ -83,6 +84,9 @@ class ServerProcess:
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
models_dir: str | None = None
|
||||
models_max: int | None = None
|
||||
no_models_autoload: bool | None = None
|
||||
lora_files: List[str] | None = None
|
||||
enable_ctx_shift: int | None = False
|
||||
draft_min: int | None = None
|
||||
@@ -143,6 +147,10 @@ class ServerProcess:
|
||||
server_args.extend(["--hf-repo", self.model_hf_repo])
|
||||
if self.model_hf_file:
|
||||
server_args.extend(["--hf-file", self.model_hf_file])
|
||||
if self.models_dir:
|
||||
server_args.extend(["--models-dir", self.models_dir])
|
||||
if self.models_max is not None:
|
||||
server_args.extend(["--models-max", self.models_max])
|
||||
if self.n_batch:
|
||||
server_args.extend(["--batch-size", self.n_batch])
|
||||
if self.n_ubatch:
|
||||
@@ -204,6 +212,8 @@ class ServerProcess:
|
||||
server_args.extend(["--draft-min", self.draft_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.no_models_autoload:
|
||||
server_args.append("--no-models-autoload")
|
||||
if self.jinja:
|
||||
server_args.append("--jinja")
|
||||
else:
|
||||
@@ -295,7 +305,13 @@ class ServerProcess:
|
||||
result = ServerResponse()
|
||||
result.headers = dict(response.headers)
|
||||
result.status_code = response.status_code
|
||||
result.body = response.json() if parse_body else None
|
||||
if parse_body:
|
||||
try:
|
||||
result.body = response.json()
|
||||
except JSONDecodeError:
|
||||
result.body = response.text
|
||||
else:
|
||||
result.body = None
|
||||
print("Response from server", json.dumps(result.body, indent=2))
|
||||
return result
|
||||
|
||||
@@ -434,8 +450,9 @@ class ServerPreset:
|
||||
@staticmethod
|
||||
def tinyllama2() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K.gguf"
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/test-model-stories260K"
|
||||
server.model_hf_file = None
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.n_ctx = 512
|
||||
server.n_batch = 32
|
||||
@@ -479,8 +496,8 @@ class ServerPreset:
|
||||
def tinyllama_infill() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
server.model_hf_repo = "ggml-org/models"
|
||||
server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
|
||||
server.model_hf_repo = "ggml-org/test-model-stories260K-infill"
|
||||
server.model_hf_file = None
|
||||
server.model_alias = "tinyllama-infill"
|
||||
server.n_ctx = 2048
|
||||
server.n_batch = 1024
|
||||
@@ -537,6 +554,7 @@ class ServerPreset:
|
||||
@staticmethod
|
||||
def router() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
server.offline = True # will be downloaded by load_all()
|
||||
# router server has no models
|
||||
server.model_file = None
|
||||
server.model_alias = None
|
||||
|
||||
@@ -15,7 +15,7 @@ sequenceDiagram
|
||||
Stores->>DB: load conversations
|
||||
Stores->>API: GET /props
|
||||
API-->>Stores: {role: "router"}
|
||||
Stores->>API: GET /models
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: models[] with status (loaded/available)
|
||||
loop each loaded model
|
||||
Stores->>API: GET /props?model=X
|
||||
@@ -28,7 +28,7 @@ sequenceDiagram
|
||||
alt model not loaded
|
||||
Stores->>API: POST /models/load
|
||||
loop poll status
|
||||
Stores->>API: GET /models
|
||||
Stores->>API: GET /v1/models
|
||||
API-->>Stores: check if loaded
|
||||
end
|
||||
Stores->>API: GET /props?model=X
|
||||
|
||||
@@ -56,7 +56,7 @@ sequenceDiagram
|
||||
UI->>modelsStore: fetchRouterModels()
|
||||
activate modelsStore
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /models
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: ApiRouterModelsListResponse
|
||||
Note right of API: {data: [{id, status, path, in_cache}]}
|
||||
modelsStore->>modelsStore: routerModels = $state(data)
|
||||
@@ -132,7 +132,7 @@ sequenceDiagram
|
||||
loop poll every 500ms (max 60 attempts)
|
||||
modelsStore->>modelsStore: fetchRouterModels()
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /models
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
API-->>ModelsSvc: models[]
|
||||
modelsStore->>modelsStore: getModelStatus(modelId)
|
||||
alt status === LOADED
|
||||
@@ -165,7 +165,7 @@ sequenceDiagram
|
||||
modelsStore->>modelsStore: pollForModelStatus(modelId, UNLOADED)
|
||||
loop poll until unloaded
|
||||
modelsStore->>ModelsSvc: listRouter()
|
||||
ModelsSvc->>API: GET /models
|
||||
ModelsSvc->>API: GET /v1/models
|
||||
end
|
||||
|
||||
modelsStore->>modelsStore: modelLoadingStates.set(modelId, false)
|
||||
|
||||
@@ -64,7 +64,10 @@
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state('');
|
||||
let pasteLongTextToFileLength = $derived(Number(currentConfig.pasteLongTextToFileLen) || 2500);
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? 2500 : n;
|
||||
});
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { ChatMessage } from '$lib/components/app';
|
||||
import { DatabaseService } from '$lib/services/database';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { getMessageSiblings } from '$lib/utils';
|
||||
@@ -19,7 +18,7 @@
|
||||
const conversation = activeConversation();
|
||||
|
||||
if (conversation) {
|
||||
DatabaseService.getConversationMessages(conversation.id).then((messages) => {
|
||||
conversationsStore.getConversationMessages(conversation.id).then((messages) => {
|
||||
allConversationMessages = messages;
|
||||
});
|
||||
} else {
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
import { Textarea } from '$lib/components/ui/textarea';
|
||||
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
||||
import { settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { ParameterSyncService } from '$lib/services/parameter-sync';
|
||||
import { ChatSettingsParameterSourceIndicator } from '$lib/components/app';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
@@ -22,7 +21,7 @@
|
||||
|
||||
// Helper function to get parameter source info for syncable parameters
|
||||
function getParameterSourceInfo(key: string) {
|
||||
if (!ParameterSyncService.canSyncParameter(key)) {
|
||||
if (!settingsStore.canSyncParameter(key)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
import { Download, Upload } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { DialogConversationSelection } from '$lib/components/app';
|
||||
import { DatabaseService } from '$lib/services/database';
|
||||
import { createMessageCountMap } from '$lib/utils';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
|
||||
let exportedConversations = $state<DatabaseConversation[]>([]);
|
||||
let importedConversations = $state<DatabaseConversation[]>([]);
|
||||
@@ -21,15 +20,15 @@
|
||||
|
||||
async function handleExportClick() {
|
||||
try {
|
||||
const allConversations = await DatabaseService.getAllConversations();
|
||||
const allConversations = conversations();
|
||||
if (allConversations.length === 0) {
|
||||
alert('No conversations to export');
|
||||
return;
|
||||
}
|
||||
|
||||
const conversationsWithMessages = await Promise.all(
|
||||
allConversations.map(async (conv) => {
|
||||
const messages = await DatabaseService.getConversationMessages(conv.id);
|
||||
allConversations.map(async (conv: DatabaseConversation) => {
|
||||
const messages = await conversationsStore.getConversationMessages(conv.id);
|
||||
return { conv, messages };
|
||||
})
|
||||
);
|
||||
@@ -47,7 +46,7 @@
|
||||
try {
|
||||
const allData: ExportedConversations = await Promise.all(
|
||||
selectedConversations.map(async (conv) => {
|
||||
const messages = await DatabaseService.getConversationMessages(conv.id);
|
||||
const messages = await conversationsStore.getConversationMessages(conv.id);
|
||||
return { conv: $state.snapshot(conv), messages: $state.snapshot(messages) };
|
||||
})
|
||||
);
|
||||
@@ -135,9 +134,7 @@
|
||||
.snapshot(fullImportData)
|
||||
.filter((item) => selectedIds.has(item.conv.id));
|
||||
|
||||
await DatabaseService.importConversations(selectedData);
|
||||
|
||||
await conversationsStore.loadConversations();
|
||||
await conversationsStore.importConversationsData(selectedData);
|
||||
|
||||
importedConversations = selectedConversations;
|
||||
showImportSummary = true;
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
import * as Table from '$lib/components/ui/table';
|
||||
import { BadgeModality, CopyToClipboardIcon } from '$lib/components/app';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import { ChatService } from '$lib/services/chat';
|
||||
import { modelsStore, modelOptions, modelsLoading } from '$lib/stores/models.svelte';
|
||||
import { formatFileSize, formatParameters, formatNumber } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
@@ -16,38 +15,24 @@
|
||||
|
||||
let serverProps = $derived(serverStore.props);
|
||||
let modelName = $derived(modelsStore.singleModelName);
|
||||
let models = $derived(modelOptions());
|
||||
let isLoadingModels = $derived(modelsLoading());
|
||||
|
||||
// Get the first model for single-model mode display
|
||||
let firstModel = $derived(models[0] ?? null);
|
||||
|
||||
// Get modalities from modelStore using the model ID from the first model
|
||||
// For now it supports only for single-model mode, will be extended with further improvements for multi-model functioanlities
|
||||
let modalities = $derived.by(() => {
|
||||
if (!modelsData?.data?.[0]?.id) return [];
|
||||
|
||||
return modelsStore.getModelModalitiesArray(modelsData.data[0].id);
|
||||
if (!firstModel?.id) return [];
|
||||
return modelsStore.getModelModalitiesArray(firstModel.id);
|
||||
});
|
||||
|
||||
let modelsData = $state<ApiModelListResponse | null>(null);
|
||||
let isLoadingModels = $state(false);
|
||||
|
||||
// Fetch models data when dialog opens
|
||||
// Ensure models are fetched when dialog opens
|
||||
$effect(() => {
|
||||
if (open && !modelsData) {
|
||||
loadModelsData();
|
||||
if (open && models.length === 0) {
|
||||
modelsStore.fetch();
|
||||
}
|
||||
});
|
||||
|
||||
async function loadModelsData() {
|
||||
isLoadingModels = true;
|
||||
|
||||
try {
|
||||
modelsData = await ChatService.getModels();
|
||||
} catch (error) {
|
||||
console.error('Failed to load models data:', error);
|
||||
// Set empty data to prevent infinite loading
|
||||
modelsData = { object: 'list', data: [] };
|
||||
} finally {
|
||||
isLoadingModels = false;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open {onOpenChange}>
|
||||
@@ -70,8 +55,8 @@
|
||||
<div class="flex items-center justify-center py-8">
|
||||
<div class="text-sm text-muted-foreground">Loading model information...</div>
|
||||
</div>
|
||||
{:else if modelsData && modelsData.data.length > 0}
|
||||
{@const modelMeta = modelsData.data[0].meta}
|
||||
{:else if firstModel}
|
||||
{@const modelMeta = firstModel.meta}
|
||||
|
||||
{#if serverProps}
|
||||
<Table.Root>
|
||||
|
||||
@@ -126,8 +126,13 @@ export const TEXT_FILE_TYPES = {
|
||||
mimeTypes: [MimeTypeText.JAVA]
|
||||
},
|
||||
[FileTypeText.CPP]: {
|
||||
extensions: [FileExtensionText.CPP, FileExtensionText.C, FileExtensionText.H],
|
||||
mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.C_SRC, MimeTypeText.C_HDR]
|
||||
extensions: [
|
||||
FileExtensionText.CPP,
|
||||
FileExtensionText.C,
|
||||
FileExtensionText.H,
|
||||
FileExtensionText.HPP
|
||||
],
|
||||
mimeTypes: [MimeTypeText.CPP_SRC, MimeTypeText.CPP_HDR, MimeTypeText.C_SRC, MimeTypeText.C_HDR]
|
||||
},
|
||||
[FileTypeText.PHP]: {
|
||||
extensions: [FileExtensionText.PHP],
|
||||
@@ -183,10 +188,30 @@ export const TEXT_FILE_TYPES = {
|
||||
},
|
||||
[FileTypeText.LATEX]: {
|
||||
extensions: [FileExtensionText.TEX],
|
||||
mimeTypes: [MimeTypeText.LATEX]
|
||||
mimeTypes: [MimeTypeText.LATEX, MimeTypeText.TEX, MimeTypeText.TEX_APP]
|
||||
},
|
||||
[FileTypeText.BIBTEX]: {
|
||||
extensions: [FileExtensionText.BIB],
|
||||
mimeTypes: [MimeTypeText.BIBTEX]
|
||||
},
|
||||
[FileTypeText.CUDA]: {
|
||||
extensions: [FileExtensionText.CU, FileExtensionText.CUH],
|
||||
mimeTypes: [MimeTypeText.CUDA]
|
||||
},
|
||||
[FileTypeText.VULKAN]: {
|
||||
extensions: [FileExtensionText.COMP],
|
||||
mimeTypes: [MimeTypeText.PLAIN]
|
||||
},
|
||||
[FileTypeText.HASKELL]: {
|
||||
extensions: [FileExtensionText.HS],
|
||||
mimeTypes: [MimeTypeText.HASKELL]
|
||||
},
|
||||
[FileTypeText.CSHARP]: {
|
||||
extensions: [FileExtensionText.CS],
|
||||
mimeTypes: [MimeTypeText.CSHARP]
|
||||
},
|
||||
[FileTypeText.PROPERTIES]: {
|
||||
extensions: [FileExtensionText.PROPERTIES],
|
||||
mimeTypes: [MimeTypeText.PROPERTIES]
|
||||
}
|
||||
} as const;
|
||||
|
||||
@@ -62,7 +62,12 @@ export enum FileTypeText {
|
||||
VUE = 'vue',
|
||||
SVELTE = 'svelte',
|
||||
LATEX = 'latex',
|
||||
BIBTEX = 'bibtex'
|
||||
BIBTEX = 'bibtex',
|
||||
CUDA = 'cuda',
|
||||
VULKAN = 'vulkan',
|
||||
HASKELL = 'haskell',
|
||||
CSHARP = 'csharp',
|
||||
PROPERTIES = 'properties'
|
||||
}
|
||||
|
||||
// File extension enums
|
||||
@@ -121,7 +126,14 @@ export enum FileExtensionText {
|
||||
VUE = '.vue',
|
||||
SVELTE = '.svelte',
|
||||
TEX = '.tex',
|
||||
BIB = '.bib'
|
||||
BIB = '.bib',
|
||||
CU = '.cu',
|
||||
CUH = '.cuh',
|
||||
COMP = '.comp',
|
||||
HPP = '.hpp',
|
||||
HS = '.hs',
|
||||
PROPERTIES = '.properties',
|
||||
CS = '.cs'
|
||||
}
|
||||
|
||||
// MIME type enums
|
||||
@@ -165,7 +177,10 @@ export enum MimeTypeText {
|
||||
CSV = 'text/csv',
|
||||
PYTHON = 'text/x-python',
|
||||
JAVA = 'text/x-java-source',
|
||||
CPP_HDR = 'text/x-c++hdr',
|
||||
CPP_SRC = 'text/x-c++src',
|
||||
CSHARP = 'text/x-csharp',
|
||||
HASKELL = 'text/x-haskell',
|
||||
C_SRC = 'text/x-csrc',
|
||||
C_HDR = 'text/x-chdr',
|
||||
PHP = 'text/x-php',
|
||||
@@ -182,6 +197,10 @@ export enum MimeTypeText {
|
||||
DART = 'text/x-dart',
|
||||
VUE = 'text/x-vue',
|
||||
SVELTE = 'text/x-svelte',
|
||||
LATEX = 'text/x-tex',
|
||||
BIBTEX = 'text/x-bibtex'
|
||||
TEX = 'text/x-tex',
|
||||
TEX_APP = 'application/x-tex',
|
||||
LATEX = 'application/x-latex',
|
||||
BIBTEX = 'text/x-bibtex',
|
||||
CUDA = 'text/x-cuda',
|
||||
PROPERTIES = 'text/properties'
|
||||
}
|
||||
|
||||
@@ -677,48 +677,6 @@ export class ChatService {
|
||||
// Utilities
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get server properties - static method for API compatibility (to be refactored)
|
||||
*/
|
||||
static async getServerProps(): Promise<ApiLlamaCppServerProps> {
|
||||
try {
|
||||
// Relative URL: resolved against the current page location
const response = await fetch(`./props`, {
|
||||
headers: getJsonHeaders()
|
||||
});
|
||||
|
||||
// fetch() does not reject on HTTP error statuses, so check `ok` explicitly
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch server props: ${response.status}`);
|
||||
}
|
||||
|
||||
// Parsed JSON is returned as-is; no shape validation happens here
const data = await response.json();
|
||||
return data;
|
||||
} catch (error) {
|
||||
// Log, then rethrow so the caller still receives the original error
console.error('Error fetching server props:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model information from /models endpoint (to be refactored)
|
||||
*/
|
||||
static async getModels(): Promise<ApiModelListResponse> {
|
||||
try {
|
||||
// Relative URL: resolved against the current page location
const response = await fetch(`./models`, {
|
||||
headers: getJsonHeaders()
|
||||
});
|
||||
|
||||
// fetch() does not reject on HTTP error statuses, so check `ok` explicitly
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch models: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
|
||||
// Parsed JSON is returned as-is; no shape validation happens here
const data = await response.json();
|
||||
return data;
|
||||
} catch (error) {
|
||||
// Log, then rethrow so the caller still receives the original error
console.error('Error fetching models:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Injects a system message at the beginning of the conversation if provided.
|
||||
* Checks for existing system messages to avoid duplication.
|
||||
|
||||
@@ -7,7 +7,7 @@ import { getJsonHeaders } from '$lib/utils';
|
||||
*
|
||||
* This service handles communication with model-related endpoints:
|
||||
* - `/v1/models` - OpenAI-compatible model list (MODEL + ROUTER mode)
|
||||
* - `/models` - Router-specific model management (ROUTER mode only)
|
||||
* - `/models/load`, `/models/unload` - Router-specific model management (ROUTER mode only)
|
||||
*
|
||||
* **Responsibilities:**
|
||||
* - List available models
|
||||
@@ -43,7 +43,7 @@ export class ModelsService {
|
||||
* Returns models with load status, paths, and other metadata
|
||||
*/
|
||||
static async listRouter(): Promise<ApiRouterModelsListResponse> {
|
||||
const response = await fetch(`${base}/models`, {
|
||||
const response = await fetch(`${base}/v1/models`, {
|
||||
headers: getJsonHeaders()
|
||||
});
|
||||
|
||||
|
||||
@@ -519,6 +519,19 @@ class ConversationsStore {
|
||||
return await DatabaseService.getConversationMessages(convId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Imports conversations from provided data (without file picker)
|
||||
* @param data - Array of conversation data with messages
|
||||
* @returns Import result with counts
|
||||
*/
|
||||
async importConversationsData(
|
||||
data: ExportedConversations
|
||||
): Promise<{ imported: number; skipped: number }> {
|
||||
// The import itself is delegated to DatabaseService; its result is returned unchanged
const result = await DatabaseService.importConversations(data);
|
||||
// Awaited only after the import resolves; presumably refreshes the store's
// conversation list — confirm against loadConversations()
await this.loadConversations();
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a message to the active messages array
|
||||
* Used by chatStore when creating new messages
|
||||
|
||||
@@ -370,6 +370,10 @@ class SettingsStore {
|
||||
return { ...this.config };
|
||||
}
|
||||
|
||||
/**
 * Thin wrapper that forwards to ParameterSyncService.canSyncParameter.
 * @param key - parameter key, passed through unchanged
 * @returns the boolean returned by ParameterSyncService.canSyncParameter(key)
 */
canSyncParameter(key: string): boolean {
|
||||
return ParameterSyncService.canSyncParameter(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get parameter information including source for a specific parameter
|
||||
*/
|
||||
|
||||
@@ -77,6 +77,13 @@ export function getFileTypeCategory(mimeType: string): FileTypeCategory | null {
|
||||
case MimeTypeText.SVELTE:
|
||||
case MimeTypeText.LATEX:
|
||||
case MimeTypeText.BIBTEX:
|
||||
case MimeTypeText.CUDA:
|
||||
case MimeTypeText.CPP_HDR:
|
||||
case MimeTypeText.CSHARP:
|
||||
case MimeTypeText.HASKELL:
|
||||
case MimeTypeText.PROPERTIES:
|
||||
case MimeTypeText.TEX:
|
||||
case MimeTypeText.TEX_APP:
|
||||
return FileTypeCategory.TEXT;
|
||||
|
||||
default:
|
||||
@@ -144,6 +151,12 @@ export function getFileTypeCategoryByExtension(filename: string): FileTypeCatego
|
||||
case FileExtensionText.SVELTE:
|
||||
case FileExtensionText.TEX:
|
||||
case FileExtensionText.BIB:
|
||||
case FileExtensionText.COMP:
|
||||
case FileExtensionText.CU:
|
||||
case FileExtensionText.CUH:
|
||||
case FileExtensionText.HPP:
|
||||
case FileExtensionText.HS:
|
||||
case FileExtensionText.PROPERTIES:
|
||||
return FileTypeCategory.TEXT;
|
||||
|
||||
default:
|
||||
|
||||
3
vendor/cpp-httplib/CMakeLists.txt
vendored
3
vendor/cpp-httplib/CMakeLists.txt
vendored
@@ -144,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
endif()
|
||||
# Windows builds with a non-MSVC compiler: additionally link crypt32 into
# ${TARGET} as a PUBLIC dependency.
# NOTE(review): presumably needed for OpenSSL certificate-store access on
# MinGW-style toolchains — confirm.
if (WIN32 AND NOT MSVC)
|
||||
target_link_libraries(${TARGET} PUBLIC crypt32)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user