mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-05 16:44:11 +00:00
Compare commits
87 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be0239693c | ||
|
|
a4090d1174 | ||
|
|
b69f1647f9 | ||
|
|
759e37b0d8 | ||
|
|
4245e622e0 | ||
|
|
c9c64dee57 | ||
|
|
c00a2634be | ||
|
|
e298d2fbd0 | ||
|
|
f0adb80bf7 | ||
|
|
f7c9429c85 | ||
|
|
1dfbf2cf3a | ||
|
|
8960efd0a6 | ||
|
|
725f23f1f3 | ||
|
|
92ecdcc06a | ||
|
|
f71f40a284 | ||
|
|
d30cb5a7fa | ||
|
|
6c35981a64 | ||
|
|
8b5e19aea6 | ||
|
|
60aea028b5 | ||
|
|
9c55e5c5c2 | ||
|
|
33d7aed4a8 | ||
|
|
6a2bc8bfb7 | ||
|
|
e3a7cf6c5b | ||
|
|
518329b2d4 | ||
|
|
2f5a4e1e09 | ||
|
|
4f41ee11d6 | ||
|
|
3e0be1cace | ||
|
|
6aa892ec2a | ||
|
|
aea9f8b4e7 | ||
|
|
06c1e4abc1 | ||
|
|
415e40a357 | ||
|
|
654a67794f | ||
|
|
5364ae4ba5 | ||
|
|
7c07ac244d | ||
|
|
0a338ed013 | ||
|
|
bc098c3cf0 | ||
|
|
c6a2c9e741 | ||
|
|
07ad2b6db3 | ||
|
|
c531edfa34 | ||
|
|
02cdd2d8b0 | ||
|
|
64bb51cf90 | ||
|
|
9c404ed54c | ||
|
|
6c8b91500e | ||
|
|
3cc1f1f1d2 | ||
|
|
c753d7bed0 | ||
|
|
b2838049cc | ||
|
|
aa48e373f2 | ||
|
|
e3a9421b78 | ||
|
|
5ab5d5fb25 | ||
|
|
3198405e98 | ||
|
|
f5170c1d7a | ||
|
|
017f10b5fa | ||
|
|
4696d56749 | ||
|
|
b7d2672082 | ||
|
|
6da34fa276 | ||
|
|
5e7d95e22e | ||
|
|
053174436f | ||
|
|
360a9c98e1 | ||
|
|
09d13d94fb | ||
|
|
24e86cae72 | ||
|
|
bb1681fbd5 | ||
|
|
d486dd3e8e | ||
|
|
21ca987fba | ||
|
|
be1d4a13db | ||
|
|
ab3971f2a0 | ||
|
|
e5c834f718 | ||
|
|
71bdbdb587 | ||
|
|
f0995d28ce | ||
|
|
c252e0c409 | ||
|
|
4f711afed5 | ||
|
|
b89d605a91 | ||
|
|
b4726345ac | ||
|
|
bf79371120 | ||
|
|
d590cd4c24 | ||
|
|
1e2809bc4b | ||
|
|
cf0a43bb64 | ||
|
|
f0d46ef157 | ||
|
|
de4c07f937 | ||
|
|
10d2af0eaa | ||
|
|
064cc596ac | ||
|
|
91159ee9df | ||
|
|
22cdab343b | ||
|
|
a71a4075cd | ||
|
|
95e18884fc | ||
|
|
df8491922f | ||
|
|
14492144c2 | ||
|
|
c104023994 |
@@ -1,4 +1,4 @@
|
||||
ARG ONEAPI_VERSION=2025.0.0-0-devel-ubuntu22.04
|
||||
ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04
|
||||
|
||||
## Build Image
|
||||
|
||||
|
||||
@@ -5,6 +5,10 @@ inputs:
|
||||
description: 'CURL version'
|
||||
required: false
|
||||
default: '8.6.0_6'
|
||||
architecture:
|
||||
description: 'Architecture of the libcurl to download'
|
||||
required: false
|
||||
default: 'win64'
|
||||
outputs:
|
||||
curl_path:
|
||||
description: "Path to the downloaded libcurl"
|
||||
@@ -18,8 +22,9 @@ runs:
|
||||
shell: powershell
|
||||
env:
|
||||
CURL_VERSION: ${{ inputs.curl_version }}
|
||||
ARCHITECTURE: ${{ inputs.architecture }}
|
||||
run: |
|
||||
curl.exe -o $env:RUNNER_TEMP/curl.zip -L "https://curl.se/windows/dl-${env:CURL_VERSION}/curl-${env:CURL_VERSION}-win64-mingw.zip"
|
||||
curl.exe -o $env:RUNNER_TEMP/curl.zip -L "https://curl.se/windows/dl-${env:CURL_VERSION}/curl-${env:CURL_VERSION}-${env:ARCHITECTURE}-mingw.zip"
|
||||
mkdir $env:RUNNER_TEMP/libcurl
|
||||
tar.exe -xvf $env:RUNNER_TEMP/curl.zip --strip-components=1 -C $env:RUNNER_TEMP/libcurl
|
||||
echo "curl_path=$env:RUNNER_TEMP/libcurl" >> $env:GITHUB_OUTPUT
|
||||
|
||||
91
.github/workflows/build-linux-cross.yml
vendored
91
.github/workflows/build-linux-cross.yml
vendored
@@ -140,3 +140,94 @@ jobs:
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-24-ppc64el-cpu-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup PowerPC64le
|
||||
run: |
|
||||
sudo dpkg --add-architecture ppc64el
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
EOF
|
||||
|
||||
sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-14-powerpc64le-linux-gnu \
|
||||
g++-14-powerpc64le-linux-gnu \
|
||||
libcurl4-openssl-dev:ppc64el
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
|
||||
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-24-ppc64el-vulkan-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup PowerPC64le
|
||||
run: |
|
||||
sudo dpkg --add-architecture ppc64el
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
EOF
|
||||
|
||||
sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
glslc \
|
||||
gcc-14-powerpc64le-linux-gnu \
|
||||
g++-14-powerpc64le-linux-gnu \
|
||||
libvulkan-dev:ppc64el \
|
||||
libcurl4-openssl-dev:ppc64el
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_VULKAN=ON \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
|
||||
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -899,7 +899,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
|
||||
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
|
||||
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
|
||||
steps:
|
||||
|
||||
11
.github/workflows/release.yml
vendored
11
.github/workflows/release.yml
vendored
@@ -238,14 +238,19 @@ jobs:
|
||||
matrix:
|
||||
include:
|
||||
- build: 'cpu-x64'
|
||||
arch: 'x64'
|
||||
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF'
|
||||
#- build: 'openblas-x64'
|
||||
# arch: 'x64'
|
||||
# defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
|
||||
- build: 'vulkan-x64'
|
||||
arch: 'x64'
|
||||
defines: '-DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
|
||||
- build: 'cpu-arm64'
|
||||
arch: 'arm64'
|
||||
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF'
|
||||
- build: 'opencl-adreno-arm64'
|
||||
arch: 'arm64'
|
||||
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
|
||||
|
||||
steps:
|
||||
@@ -312,6 +317,8 @@ jobs:
|
||||
- name: libCURL
|
||||
id: get_libcurl
|
||||
uses: ./.github/actions/windows-setup-curl
|
||||
with:
|
||||
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -339,7 +346,7 @@ jobs:
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
Copy-Item $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll
|
||||
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
|
||||
7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\*
|
||||
|
||||
- name: Upload artifacts
|
||||
@@ -441,7 +448,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
|
||||
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
|
||||
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
|
||||
steps:
|
||||
|
||||
@@ -572,4 +572,11 @@ automatically. For example:
|
||||
$ echo "source ~/.llama-completion.bash" >> ~/.bashrc
|
||||
```
|
||||
|
||||
## References
|
||||
## Dependencies
|
||||
|
||||
- [yhirose/cpp-httplib](https://github.com/yhirose/cpp-httplib) - Single-header HTTP server, used by `llama-server` - MIT license
|
||||
- [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain
|
||||
- [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License
|
||||
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
|
||||
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
|
||||
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
|
||||
|
||||
@@ -117,6 +117,7 @@ setup_framework_structure() {
|
||||
# Copy all required headers (common for all platforms)
|
||||
cp include/llama.h ${header_path}
|
||||
cp ggml/include/ggml.h ${header_path}
|
||||
cp ggml/include/ggml-opt.h ${header_path}
|
||||
cp ggml/include/ggml-alloc.h ${header_path}
|
||||
cp ggml/include/ggml-backend.h ${header_path}
|
||||
cp ggml/include/ggml-metal.h ${header_path}
|
||||
|
||||
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
|
||||
minja/minja.hpp
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
@@ -119,8 +121,8 @@ if (LLAMA_LLGUIDANCE)
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.7.19 (+ fancy-regex build fix):
|
||||
GIT_TAG b59f98f85269892a7de3d3641ad155366f13daa6
|
||||
# v0.7.20 (+ fix to build on GCC 15):
|
||||
GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8
|
||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||
BUILD_IN_SOURCE TRUE
|
||||
|
||||
@@ -1445,6 +1445,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.n_keep = value;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--swa-full"},
|
||||
string_format("use full-size SWA cache (default: %s)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.swa_full = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_SWA_FULL"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
@@ -2057,13 +2065,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.grp_attn_w = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||
add_opt(common_arg(
|
||||
{"-dkvc", "--dump-kv-cache"},
|
||||
"verbose print of the KV cache",
|
||||
[](common_params & params) {
|
||||
params.dump_kv_cache = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-nkvo", "--no-kv-offload"},
|
||||
"disable KV offload",
|
||||
@@ -2585,7 +2586,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, int value) {
|
||||
params.n_junk = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_PASSKEY}));
|
||||
).set_examples({LLAMA_EXAMPLE_PASSKEY, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--pos"}, "N",
|
||||
string_format("position of the passkey in the junk text (default: %d)", params.i_pos),
|
||||
@@ -2648,7 +2649,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.is_pp_shared = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_BENCH}));
|
||||
).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"-npp"}, "n0,n1,...",
|
||||
"number of prompt tokens",
|
||||
@@ -2880,6 +2881,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.chat_template = read_file(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"--no-prefill-assistant"},
|
||||
string_format(
|
||||
"whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n"
|
||||
"when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n"
|
||||
),
|
||||
[](common_params & params) {
|
||||
params.prefill_assistant = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_PREFILL_ASSISTANT"));
|
||||
add_opt(common_arg(
|
||||
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
|
||||
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
|
||||
|
||||
240
common/chat.cpp
240
common/chat.cpp
@@ -6,6 +6,15 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
auto local_time = *std::localtime(&time);
|
||||
std::ostringstream ss;
|
||||
ss << std::put_time(&local_time, format.c_str());
|
||||
auto res = ss.str();
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
@@ -24,6 +33,7 @@ struct templates_params {
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
if (!inputs.tools.is_null()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
});
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"date_string", format_time(inputs.now, "%d %b %Y")},
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
if (!inputs.tools.is_null()) {
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
params.extract_reasoning = inputs.extract_reasoning;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
|
||||
@@ -443,6 +443,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
@@ -1083,6 +1102,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
|
||||
}
|
||||
|
||||
mparams.progress_callback = params.load_progress_callback;
|
||||
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
|
||||
|
||||
return mparams;
|
||||
}
|
||||
|
||||
@@ -1114,6 +1136,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
cparams.swa_full = params.swa_full;
|
||||
|
||||
if (params.reranking) {
|
||||
cparams.embeddings = true;
|
||||
@@ -1306,81 +1329,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||
return text;
|
||||
}
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
//
|
||||
|
||||
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
|
||||
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
|
||||
|
||||
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
|
||||
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
||||
|
||||
llama_kv_cache_view_cell * c_curr = view.cells;
|
||||
llama_seq_id * cs_curr = view.cells_sequences;
|
||||
|
||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
||||
if (i % row_size == 0) {
|
||||
printf("\n%5d: ", i);
|
||||
}
|
||||
int seq_count = 0;
|
||||
for (int j = 0; j < view.n_seq_max; j++) {
|
||||
if (cs_curr[j] >= 0) { seq_count++; }
|
||||
}
|
||||
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
|
||||
}
|
||||
|
||||
printf("\n=== Done dumping\n");
|
||||
}
|
||||
|
||||
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
|
||||
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
||||
|
||||
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
|
||||
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
||||
|
||||
std::unordered_map<llama_seq_id, size_t> seqs;
|
||||
llama_kv_cache_view_cell * c_curr = view.cells;
|
||||
llama_seq_id * cs_curr = view.cells_sequences;
|
||||
|
||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
||||
for (int j = 0; j < view.n_seq_max; j++) {
|
||||
if (cs_curr[j] < 0) { continue; }
|
||||
if (seqs.find(cs_curr[j]) == seqs.end()) {
|
||||
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
||||
const size_t sz = seqs.size();
|
||||
seqs[cs_curr[j]] = sz;
|
||||
}
|
||||
}
|
||||
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
||||
}
|
||||
|
||||
printf("=== Sequence legend: ");
|
||||
for (const auto & it : seqs) {
|
||||
printf("%zu=%d, ", it.second, it.first);
|
||||
}
|
||||
printf("'+'=other sequence ids");
|
||||
|
||||
c_curr = view.cells;
|
||||
cs_curr = view.cells_sequences;
|
||||
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
||||
if (i % row_size == 0) {
|
||||
printf("\n%5d: ", i);
|
||||
}
|
||||
for (int j = 0; j < view.n_seq_max; j++) {
|
||||
if (cs_curr[j] >= 0) {
|
||||
const auto & it = seqs.find(cs_curr[j]);
|
||||
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
|
||||
} else {
|
||||
putchar('.');
|
||||
}
|
||||
}
|
||||
putchar(' ');
|
||||
}
|
||||
|
||||
printf("\n=== Done dumping\n");
|
||||
}
|
||||
|
||||
//
|
||||
// Embedding utils
|
||||
//
|
||||
@@ -1565,3 +1513,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
|
||||
const int64_t ne_datapoint = llama_n_ctx(ctx);
|
||||
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
|
||||
ggml_opt_dataset_t result = ggml_opt_dataset_init(
|
||||
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
|
||||
|
||||
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
|
||||
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
|
||||
|
||||
for (int64_t idata = 0; idata < ndata; ++idata) {
|
||||
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
|
||||
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
@@ -322,13 +323,13 @@ struct common_params {
|
||||
bool flash_attn = false; // flash attention
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = true; // context shift on inifinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
|
||||
bool no_kv_offload = false; // disable KV offloading
|
||||
bool warmup = true; // warmup run
|
||||
bool check_tensors = false; // validate tensor data
|
||||
@@ -367,6 +368,7 @@ struct common_params {
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
@@ -426,6 +428,11 @@ struct common_params {
|
||||
|
||||
// common params
|
||||
std::string out_file; // output filename for all example programs
|
||||
// optional callback for model loading progress and cancellation:
|
||||
// called with a progress value between 0.0 and 1.0.
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -503,10 +510,9 @@ static bool string_starts_with(const std::string & str,
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
static bool string_ends_with(const std::string & str,
|
||||
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
@@ -615,16 +621,6 @@ std::string common_detokenize(
|
||||
const std::vector<llama_token> & tokens,
|
||||
bool special = true);
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
//
|
||||
|
||||
// Dump the KV cache view with the number of sequences per cell.
|
||||
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
|
||||
|
||||
// Dump the KV cache view showing individual sequences in each cell (long output).
|
||||
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
||||
|
||||
//
|
||||
// Embedding utils
|
||||
//
|
||||
@@ -666,3 +662,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
|
||||
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// training utils
|
||||
//
|
||||
|
||||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <exception>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -393,8 +395,8 @@ class chat_template {
|
||||
|
||||
for (const auto & message_ : adjusted_messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
|
||||
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
@@ -415,7 +417,6 @@ class chat_template {
|
||||
}
|
||||
}
|
||||
if (polyfill_tool_calls) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
@@ -434,8 +435,11 @@ class chat_template {
|
||||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
if (message.contains("content")) {
|
||||
auto content = message.at("content");
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
}
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
@@ -233,7 +234,7 @@ public:
|
||||
}
|
||||
} else if (is_object()) {
|
||||
if (!index.is_hashable())
|
||||
throw std::runtime_error("Unashable type: " + index.dump());
|
||||
throw std::runtime_error("Unhashable type: " + index.dump());
|
||||
auto it = object_->find(index.primitive_);
|
||||
if (it == object_->end())
|
||||
throw std::runtime_error("Key not found: " + index.dump());
|
||||
@@ -252,7 +253,7 @@ public:
|
||||
auto index = key.get<int>();
|
||||
return array_->at(index < 0 ? array_->size() + index : index);
|
||||
} else if (object_) {
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||
auto it = object_->find(key.primitive_);
|
||||
if (it == object_->end()) return Value();
|
||||
return it->second;
|
||||
@@ -261,7 +262,7 @@ public:
|
||||
}
|
||||
void set(const Value& key, const Value& value) {
|
||||
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||
(*object_)[key.primitive_] = value;
|
||||
}
|
||||
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
|
||||
@@ -398,7 +399,7 @@ public:
|
||||
}
|
||||
return false;
|
||||
} else if (object_) {
|
||||
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
|
||||
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
|
||||
return object_->find(value.primitive_) != object_->end();
|
||||
} else {
|
||||
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
|
||||
@@ -416,7 +417,7 @@ public:
|
||||
return const_cast<Value*>(this)->at(index);
|
||||
}
|
||||
Value& at(const Value & index) {
|
||||
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
||||
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||
if (is_array()) return array_->at(index.get<int>());
|
||||
if (is_object()) return object_->at(index.primitive_);
|
||||
throw std::runtime_error("Value is not an array or object: " + dump());
|
||||
@@ -676,8 +677,8 @@ public:
|
||||
class VariableExpr : public Expression {
|
||||
std::string name;
|
||||
public:
|
||||
VariableExpr(const Location & location, const std::string& n)
|
||||
: Expression(location), name(n) {}
|
||||
VariableExpr(const Location & loc, const std::string& n)
|
||||
: Expression(loc), name(n) {}
|
||||
std::string get_name() const { return name; }
|
||||
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
|
||||
if (!context->contains(name)) {
|
||||
@@ -1200,9 +1201,9 @@ public:
|
||||
|
||||
class SliceExpr : public Expression {
|
||||
public:
|
||||
std::shared_ptr<Expression> start, end;
|
||||
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
|
||||
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
|
||||
std::shared_ptr<Expression> start, end, step;
|
||||
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
|
||||
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
|
||||
Value do_evaluate(const std::shared_ptr<Context> &) const override {
|
||||
throw std::runtime_error("SliceExpr not implemented");
|
||||
}
|
||||
@@ -1219,18 +1220,35 @@ public:
|
||||
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
|
||||
auto target_value = base->evaluate(context);
|
||||
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
|
||||
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
|
||||
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
|
||||
auto len = target_value.size();
|
||||
auto wrap = [len](int64_t i) -> int64_t {
|
||||
if (i < 0) {
|
||||
return i + len;
|
||||
}
|
||||
return i;
|
||||
};
|
||||
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
|
||||
if (!step) {
|
||||
throw std::runtime_error("slice step cannot be zero");
|
||||
}
|
||||
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
|
||||
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
|
||||
if (target_value.is_string()) {
|
||||
std::string s = target_value.get<std::string>();
|
||||
if (start < 0) start = s.size() + start;
|
||||
if (end < 0) end = s.size() + end;
|
||||
return s.substr(start, end - start);
|
||||
|
||||
std::string result;
|
||||
if (start < end && step == 1) {
|
||||
result = s.substr(start, end - start);
|
||||
} else {
|
||||
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||
result += s[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
||||
} else if (target_value.is_array()) {
|
||||
if (start < 0) start = target_value.size() + start;
|
||||
if (end < 0) end = target_value.size() + end;
|
||||
auto result = Value::array();
|
||||
for (auto i = start; i < end; ++i) {
|
||||
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||
result.push_back(target_value.at(i));
|
||||
}
|
||||
return result;
|
||||
@@ -1305,6 +1323,8 @@ public:
|
||||
if (name == "iterable") return l.is_iterable();
|
||||
if (name == "sequence") return l.is_array();
|
||||
if (name == "defined") return !l.is_null();
|
||||
if (name == "true") return l.to_bool();
|
||||
if (name == "false") return !l.to_bool();
|
||||
throw std::runtime_error("Unknown type for 'is' operator: " + name);
|
||||
};
|
||||
auto value = eval();
|
||||
@@ -1520,6 +1540,10 @@ public:
|
||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||
auto suffix = vargs.args[0].get<std::string>();
|
||||
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
|
||||
} else if (method->get_name() == "startswith") {
|
||||
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
|
||||
auto prefix = vargs.args[0].get<std::string>();
|
||||
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
|
||||
} else if (method->get_name() == "title") {
|
||||
vargs.expectArgs("title method", {0, 0}, {0, 0});
|
||||
auto res = str;
|
||||
@@ -2082,28 +2106,37 @@ private:
|
||||
|
||||
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
|
||||
if (!consumeToken("[").empty()) {
|
||||
std::shared_ptr<Expression> index;
|
||||
std::shared_ptr<Expression> index;
|
||||
auto slice_loc = get_location();
|
||||
std::shared_ptr<Expression> start, end, step;
|
||||
bool has_first_colon = false, has_second_colon = false;
|
||||
|
||||
if (!peekSymbols({ ":" })) {
|
||||
start = parseExpression();
|
||||
}
|
||||
|
||||
if (!consumeToken(":").empty()) {
|
||||
has_first_colon = true;
|
||||
if (!peekSymbols({ ":", "]" })) {
|
||||
end = parseExpression();
|
||||
}
|
||||
if (!consumeToken(":").empty()) {
|
||||
auto slice_end = parseExpression();
|
||||
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
|
||||
} else {
|
||||
auto slice_start = parseExpression();
|
||||
if (!consumeToken(":").empty()) {
|
||||
consumeSpaces();
|
||||
if (peekSymbols({ "]" })) {
|
||||
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
|
||||
} else {
|
||||
auto slice_end = parseExpression();
|
||||
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
|
||||
}
|
||||
} else {
|
||||
index = std::move(slice_start);
|
||||
has_second_colon = true;
|
||||
if (!peekSymbols({ "]" })) {
|
||||
step = parseExpression();
|
||||
}
|
||||
}
|
||||
if (!index) throw std::runtime_error("Empty index in subscript");
|
||||
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
|
||||
}
|
||||
|
||||
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
|
||||
if ((has_first_colon || has_second_colon) && (start || end || step)) {
|
||||
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
|
||||
} else {
|
||||
index = std::move(start);
|
||||
}
|
||||
if (!index) throw std::runtime_error("Empty index in subscript");
|
||||
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
|
||||
|
||||
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
|
||||
} else if (!consumeToken(".").empty()) {
|
||||
auto identifier = parseIdentifier();
|
||||
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
|
||||
|
||||
204
common/regex-partial.cpp
Normal file
204
common/regex-partial.cpp
Normal file
@@ -0,0 +1,204 @@
|
||||
#include "regex-partial.h"
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
common_regex::common_regex(const std::string & pattern) :
|
||||
pattern(pattern),
|
||||
rx(pattern),
|
||||
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||
|
||||
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
||||
std::smatch match;
|
||||
if (pos > input.size()) {
|
||||
throw std::runtime_error("Position out of bounds");
|
||||
}
|
||||
auto start = input.begin() + pos;
|
||||
auto found = as_match
|
||||
? std::regex_match(start, input.end(), match, rx)
|
||||
: std::regex_search(start, input.end(), match, rx);
|
||||
if (found) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
||||
for (size_t i = 0; i < match.size(); ++i) {
|
||||
auto begin = pos + match.position(i);
|
||||
res.groups.emplace_back(begin, begin + match.length(i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
|
||||
auto group = srmatch[1].str();
|
||||
if (group.length() != 0) {
|
||||
auto it = srmatch[1].second.base();
|
||||
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
||||
if ((!as_match) || it == input.begin()) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
||||
const size_t begin = std::distance(input.begin(), it);
|
||||
const size_t end = input.size();
|
||||
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
res.groups.push_back({begin, end});
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||
|
||||
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||
|
||||
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
|
||||
- /a|b/ -> (a|b).*
|
||||
- /a*?/ -> error, could match ""
|
||||
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
|
||||
- /.*?ab/ -> ((?:b)?a).* (merge .*)
|
||||
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
|
||||
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
|
||||
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
|
||||
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
|
||||
|
||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
|
||||
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
|
||||
*/
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
auto it = pattern.begin();
|
||||
const auto end = pattern.end();
|
||||
|
||||
std::function<std::string()> process = [&]() {
|
||||
std::vector<std::vector<std::string>> alternatives(1);
|
||||
std::vector<std::string> * sequence = &alternatives.back();
|
||||
|
||||
while (it != end) {
|
||||
if (*it == '[') {
|
||||
auto start = it;
|
||||
++it;
|
||||
while (it != end) {
|
||||
if ((*it == '\\') && (++it != end)) {
|
||||
++it;
|
||||
} else if ((it != end) && (*it == ']')) {
|
||||
break;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '[' in pattern");
|
||||
}
|
||||
++it;
|
||||
sequence->push_back(std::string(start, it));
|
||||
} else if (*it == '*' || *it == '?' || *it == '+') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Quantifier without preceding element");
|
||||
}
|
||||
sequence->back() += *it;
|
||||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (*it == '?') {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
} else if (*it == '{') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Repetition without preceding element");
|
||||
}
|
||||
++it;
|
||||
auto start = it;
|
||||
while (it != end && *it != '}') {
|
||||
++it;
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '{' in pattern");
|
||||
}
|
||||
auto parts = string_split(std::string(start, it), ",");
|
||||
++it;
|
||||
if (parts.size() > 2) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
|
||||
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
return std::stoi(s);
|
||||
};
|
||||
auto min = parseOptInt(parts[0], 0);
|
||||
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
||||
if (min && max && *max < *min) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
||||
auto part = sequence->back();
|
||||
sequence->pop_back();
|
||||
for (int i = 0; i < *min; i++) {
|
||||
sequence->push_back(part);
|
||||
}
|
||||
if (max) {
|
||||
for (int i = *min; i < *max; i++) {
|
||||
sequence->push_back(part + "?");
|
||||
}
|
||||
} else {
|
||||
sequence->push_back(part + "*");
|
||||
}
|
||||
} else if (*it == '(') {
|
||||
++it;
|
||||
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
||||
it += 2;
|
||||
}
|
||||
auto sub = process();
|
||||
if (*it != ')') {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
++it;
|
||||
auto & part = sequence->emplace_back("(?:");
|
||||
part += sub;
|
||||
part += ")";
|
||||
} else if (*it == ')') {
|
||||
break;
|
||||
} else if (*it == '|') {
|
||||
++it;
|
||||
alternatives.emplace_back();
|
||||
sequence = &alternatives.back();
|
||||
} else if (*it == '\\' && (++it != end)) {
|
||||
auto str = std::string("\\") + *it;
|
||||
sequence->push_back(str);
|
||||
++it;
|
||||
} else if (it != end) {
|
||||
sequence->push_back(std::string(1, *it));
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
|
||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||
std::vector<std::string> res_alts;
|
||||
for (const auto & parts : alternatives) {
|
||||
auto & res = res_alts.emplace_back();
|
||||
for (size_t i = 0; i < parts.size() - 1; i++) {
|
||||
res += "(?:";
|
||||
}
|
||||
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
||||
res += *it;
|
||||
if (it != parts.rend() - 1) {
|
||||
res += ")?";
|
||||
}
|
||||
}
|
||||
}
|
||||
return string_join(res_alts, "|");
|
||||
};
|
||||
auto res = process();
|
||||
if (it != end) {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
|
||||
return "(" + res + ")[\\s\\S]*";
|
||||
}
|
||||
56
common/regex-partial.h
Normal file
56
common/regex-partial.h
Normal file
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
enum common_regex_match_type {
|
||||
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||
};
|
||||
|
||||
struct common_string_range {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
if (begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
}
|
||||
// prevent default ctor
|
||||
common_string_range() = delete;
|
||||
bool empty() const {
|
||||
return begin == end;
|
||||
}
|
||||
bool operator==(const common_string_range & other) const {
|
||||
return begin == other.begin && end == other.end;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_regex_match {
|
||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||
std::vector<common_string_range> groups;
|
||||
|
||||
bool operator==(const common_regex_match & other) const {
|
||||
return type == other.type && groups == other.groups;
|
||||
}
|
||||
bool operator!=(const common_regex_match & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
class common_regex {
|
||||
std::string pattern;
|
||||
std::regex rx;
|
||||
std::regex rx_reversed_partial;
|
||||
|
||||
public:
|
||||
explicit common_regex(const std::string & pattern);
|
||||
|
||||
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||
|
||||
const std::string & str() const { return pattern; }
|
||||
};
|
||||
|
||||
// For testing only (pretty print of failures).
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||
@@ -308,6 +308,7 @@ class ModelBase:
|
||||
gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
gguf.MODEL_TENSOR.POSNET_NORM1,
|
||||
gguf.MODEL_TENSOR.POSNET_NORM2,
|
||||
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
)
|
||||
)
|
||||
or not new_name.endswith(".weight")
|
||||
@@ -2069,6 +2070,9 @@ class Llama4Model(LlamaModel):
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.startswith("language_model."):
|
||||
name = name.replace("language_model.", "")
|
||||
|
||||
# split the gate_up into gate and up
|
||||
if "gate_up_proj" in name:
|
||||
name_up = name.replace("gate_up_proj", "up_proj.weight")
|
||||
@@ -2089,6 +2093,26 @@ class Llama4Model(LlamaModel):
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Llama4ForConditionalGeneration")
|
||||
class Llama4VisionModel(VisionModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
|
||||
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
|
||||
assert self.hparams["hidden_act"] == "gelu"
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
if "multi_modal_projector" in name or "vision_model" in name:
|
||||
# process vision tensors
|
||||
if "positional_embedding_vlm" in name and ".weight" not in name:
|
||||
name += ".weight"
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
return []
|
||||
|
||||
|
||||
@ModelBase.register("Mistral3ForConditionalGeneration")
|
||||
class Mistral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
@@ -5746,11 +5770,20 @@ class GraniteModel(LlamaModel):
|
||||
logger.info("gguf: (granite) logits_scale = %s", logits_scale)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteMoeForCausalLM")
|
||||
@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM")
|
||||
class GraniteMoeModel(GraniteModel):
|
||||
"""Conversion for IBM's GraniteMoeForCausalLM"""
|
||||
model_arch = gguf.MODEL_ARCH.GRANITE_MOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
"""GraniteMoeShared uses GraniteMoe parameters plus the following:
|
||||
- shared_intermediate_size
|
||||
"""
|
||||
super().set_gguf_parameters()
|
||||
if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
|
||||
logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
|
||||
is used. This essentially merges w1 and w3 into a single tensor with 2x
|
||||
@@ -5761,12 +5794,21 @@ class GraniteMoeModel(GraniteModel):
|
||||
if name.endswith("block_sparse_moe.input_linear.weight"):
|
||||
ffn_dim = self.hparams["intermediate_size"]
|
||||
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
|
||||
gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
|
||||
gate, up = data_torch.split(ffn_dim, dim=-2)
|
||||
return [
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
|
||||
]
|
||||
|
||||
if name.endswith("shared_mlp.input_linear.weight"):
|
||||
ffn_dim = self.hparams["shared_intermediate_size"]
|
||||
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
|
||||
gate, up = data_torch.split(ffn_dim, dim=-2)
|
||||
return [
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
|
||||
]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
|
||||
@@ -56,60 +56,82 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi
|
||||
|
||||
## Model Supports
|
||||
|
||||
| Model Name | FP16 | Q8_0 | Q4_0 |
|
||||
| Model Name | FP16 | Q4_0 | Q8_0 |
|
||||
|:----------------------------|:-----:|:----:|:----:|
|
||||
| AquilaChat2-7B | √ | √ | √ |
|
||||
| Baichuan-7b | √ | √ | √ |
|
||||
| Baichuan2-7B-Chat | √ | √ | √ |
|
||||
| bitnet_b1_58-large | √ | √ | √ |
|
||||
| bloom-560m | √ | x | √ |
|
||||
| bloomz-alpaca-560m | √ | x | √ |
|
||||
| c4ai-command-r-35B-v01 | x | x | x |
|
||||
| chatglm3-6B | x | x | x |
|
||||
| chinese-alpaca-2-1.3b | √ | √ | √ |
|
||||
| CodeShell-7B | √ | √ | √ |
|
||||
| deepseek-ai_deepseek-coder-1.3B-base | x | x | x |
|
||||
| deepseek-ai_DeepSeek-V2-Lite | x | x | x |
|
||||
| deepseek-coder-6.7B-instruct | x | x | x |
|
||||
| DeepSeek-V2-Lite-64x1.5B | x | x | x |
|
||||
| falcon-7b-instruct | √ | √ | √ |
|
||||
| flan-t5-large | √ | √ | √ |
|
||||
| gemma-2-9b-it | √ | √ | √ |
|
||||
| glm-4-9B | x | x | x |
|
||||
| gpt2 | √ | √ | √ |
|
||||
| Gpt2-163M | √ | √ | √ |
|
||||
| granite-3B-code-instruct | √ | √ | √ |
|
||||
| Llama-2 | √ | √ | √ |
|
||||
| Llama-3 | √ | √ | √ |
|
||||
| Mistral-7B | √ | √ | √ |
|
||||
| Mistral MOE | √ | √ | √ |
|
||||
| DBRX | - | - | - |
|
||||
| Falcon | √ | √ | √ |
|
||||
| Chinese LLaMA/Alpaca | √ | √ | √ |
|
||||
| Vigogne(French) | √ | √ | √ |
|
||||
| BERT | x | x | x |
|
||||
| Koala | √ | √ | √ |
|
||||
| Baichuan | √ | √ | √ |
|
||||
| Aquila 1 & 2 | √ | √ | √ |
|
||||
| Starcoder models | √ | √ | √ |
|
||||
| Refact | √ | √ | √ |
|
||||
| MPT | √ | √ | √ |
|
||||
| Bloom | √ | √ | √ |
|
||||
| Yi models | √ | √ | √ |
|
||||
| stablelm models | √ | √ | √ |
|
||||
| DeepSeek models | x | x | x |
|
||||
| Qwen models | √ | √ | √ |
|
||||
| PLaMo-13B | √ | √ | √ |
|
||||
| Phi models | √ | √ | √ |
|
||||
| PhiMoE | √ | √ | √ |
|
||||
| GPT-2 | √ | √ | √ |
|
||||
| Orion | √ | √ | √ |
|
||||
| InternlLM2 | √ | √ | √ |
|
||||
| CodeShell | √ | √ | √ |
|
||||
| Gemma | √ | √ | √ |
|
||||
| Mamba | √ | √ | √ |
|
||||
| Xverse | √ | √ | √ |
|
||||
| command-r models | √ | √ | √ |
|
||||
| Grok-1 | - | - | - |
|
||||
| SEA-LION | √ | √ | √ |
|
||||
| GritLM-7B | √ | √ | √ |
|
||||
| internlm2_5-7b-chat | √ | √ | √ |
|
||||
| koala-7B-HF | √ | √ | √ |
|
||||
| Llama-2-7b-chat-hf | √ | √ | √ |
|
||||
| Llama-3-Smaug-8B | √ | √ | √ |
|
||||
| Llama2-Chinese-7b-Chat | √ | √ | √ |
|
||||
| Llama3-8B | √ | √ | √ |
|
||||
| Llama3-8b-chinese | √ | √ | √ |
|
||||
| mamba-130m-hf | √ | √ | √ |
|
||||
| Mistral-7B-Instruct-v0.2 | √ | √ | √ |
|
||||
| Mixtral-8x7B-Instruct-v0.1 | x | √ | √ |
|
||||
| mpt-7B | √ | √ | √ |
|
||||
| OLMo-1B-hf | √ | √ | √ |
|
||||
| OpenELM-3B-Instruct | √ | √ | √ |
|
||||
| Orion-14b-base | √ | √ | √ |
|
||||
| phi1 | x | x | x |
|
||||
| phi2 | x | x | x |
|
||||
| Phi-3-mini-4k-instruct | √ | √ | √ |
|
||||
| plamo-13b | √ | √ | √ |
|
||||
| pythia-70M | x | x | x |
|
||||
| Qwen-7B | √ | √ | √ |
|
||||
| Qwen2-1.5B-Instruct | √ | x | √ |
|
||||
| Refact-1_6B-fim | √ | √ | √ |
|
||||
| SmolLM-135M | √ | √ | √ |
|
||||
| stablelm-zephyr | x | x | x |
|
||||
| stablelm-2-zephyr-1_6b | x | x | x |
|
||||
| starcoderbase-1b | √ | √ | √ |
|
||||
| starcoder2-3b | √ | √ | √ |
|
||||
| vigogne-7b-chat | √ | √ | √ |
|
||||
| xverse-7b-chat | √ | √ | √ |
|
||||
| Yi-6b-Chat | √ | √ | √ |
|
||||
| OLMo | √ | √ | √ |
|
||||
| OLMo 2 | √ | √ | √ |
|
||||
| OLMoE | √ | √ | √ |
|
||||
| Granite models | √ | √ | √ |
|
||||
| GPT-NeoX | √ | √ | √ |
|
||||
| Pythia | √ | √ | √ |
|
||||
| Snowflake-Arctic MoE | - | - | - |
|
||||
| Smaug | √ | √ | √ |
|
||||
| Poro 34B | √ | √ | √ |
|
||||
| Bitnet b1.58 models | √ | x | x |
|
||||
| Flan-T5 | √ | √ | √ |
|
||||
| Open Elm models | x | √ | √ |
|
||||
| chatGLM3-6B + ChatGLM4-9b + GLMEdge-1.5b + GLMEdge-4b | √ | √ | √ |
|
||||
| GLM-4-0414 | √ | √ | √ |
|
||||
| SmolLM | √ | √ | √ |
|
||||
| EXAONE-3.0-7.8B-Instruct | √ | √ | √ |
|
||||
| FalconMamba Models | √ | √ | √ |
|
||||
| Jais Models | - | x | x |
|
||||
| Bielik-11B-v2.3 | √ | √ | √ |
|
||||
| RWKV-6 | - | √ | √ |
|
||||
| QRWKV-6 | √ | √ | √ |
|
||||
| GigaChat-20B-A3B | x | x | x |
|
||||
| Trillion-7B-preview | √ | √ | √ |
|
||||
| Ling models | √ | √ | √ |
|
||||
|
||||
|
||||
**Multimodal**
|
||||
| Model Name | FP16 | Q4_0 | Q8_0 |
|
||||
|:----------------------------|:-----:|:----:|:----:|
|
||||
| LLaVA 1.5 models, LLaVA 1.6 models | x | x | x |
|
||||
| BakLLaVA | √ | √ | √ |
|
||||
| Obsidian | √ | - | - |
|
||||
| ShareGPT4V | x | - | - |
|
||||
| MobileVLM 1.7B/3B models | - | - | - |
|
||||
| Yi-VL | - | - | - |
|
||||
| Mini CPM | √ | √ | √ |
|
||||
| Moondream | √ | √ | √ |
|
||||
| Bunny | √ | - | - |
|
||||
| GLM-EDGE | √ | √ | √ |
|
||||
| Qwen2-VL | √ | √ | √ |
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -17,25 +17,25 @@
|
||||
|
||||
**SYCL** is a high-level parallel programming model designed to improve developers productivity writing code across various hardware accelerators such as CPUs, GPUs, and FPGAs. It is a single-source language designed for heterogeneous computing and based on standard C++17.
|
||||
|
||||
**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:
|
||||
**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to Intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:
|
||||
|
||||
- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
|
||||
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*.
|
||||
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
|
||||
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs.
|
||||
- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
|
||||
|
||||
### Llama.cpp + SYCL
|
||||
|
||||
The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it also supports other vendor GPUs: Nvidia and AMD.
|
||||
The llama.cpp SYCL backend is primarily designed for **Intel GPUs**.
|
||||
SYCL cross-platform capabilities enable support for Nvidia GPUs as well, with limited support for AMD.
|
||||
|
||||
## Recommended Release
|
||||
|
||||
The SYCL backend would be broken by some PRs due to no online CI.
|
||||
|
||||
The following release is verified with good quality:
|
||||
The following releases are verified and recommended:
|
||||
|
||||
|Commit ID|Tag|Release|Verified Platform| Update date|
|
||||
|-|-|-|-|-|
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|
||||
@@ -106,15 +106,14 @@ SYCL backend supports Intel GPU Family:
|
||||
|-------------------------------|---------|---------------------------------------|
|
||||
| Intel Data Center Max Series | Support | Max 1550, 1100 |
|
||||
| Intel Data Center Flex Series | Support | Flex 170 |
|
||||
| Intel Arc Series | Support | Arc 770, 730M, Arc A750 |
|
||||
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake |
|
||||
| Intel iGPU | Support | iGPU in 13700k,iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 |
|
||||
| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 |
|
||||
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
|
||||
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |
|
||||
|
||||
*Notes:*
|
||||
|
||||
- **Memory**
|
||||
- The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-cli`.
|
||||
|
||||
- Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU.
|
||||
|
||||
- **Execution Unit (EU)**
|
||||
@@ -138,9 +137,11 @@ Note: AMD GPU support is highly experimental and is incompatible with F16.
|
||||
Additionally, it only supports GPUs with a sub_group_size (warp size) of 32.
|
||||
|
||||
## Docker
|
||||
The docker build option is currently limited to *intel GPU* targets.
|
||||
|
||||
The docker build option is currently limited to *Intel GPU* targets.
|
||||
|
||||
### Build image
|
||||
|
||||
```sh
|
||||
# Using FP16
|
||||
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
|
||||
@@ -148,9 +149,10 @@ docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f
|
||||
|
||||
*Notes*:
|
||||
|
||||
To build in default FP32 *(Slower than FP16 alternative)*, you can remove the `--build-arg="GGML_SYCL_F16=ON"` argument from the previous command.
|
||||
To build in default FP32 *(Slower than FP16 alternative)*, set `--build-arg="GGML_SYCL_F16=OFF"` in the previous command.
|
||||
|
||||
You can also use the `.devops/llama-server-intel.Dockerfile`, which builds the *"server"* alternative.
|
||||
Check the [documentation for Docker](../docker.md) to see the available images.
|
||||
|
||||
### Run container
|
||||
|
||||
@@ -250,7 +252,7 @@ sycl-ls
|
||||
|
||||
- **Intel GPU**
|
||||
|
||||
When targeting an intel GPU, the user should expect one or more level-zero devices among the available SYCL devices. Please make sure that at least one GPU is present, for instance [`level_zero:gpu`] in the sample output below:
|
||||
When targeting an intel GPU, the user should expect one or more devices among the available SYCL devices. Please make sure that at least one GPU is present via `sycl-ls`, for instance `[level_zero:gpu]` in the sample output below:
|
||||
|
||||
```
|
||||
[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
|
||||
@@ -282,7 +284,7 @@ For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]:
|
||||
|
||||
#### Intel GPU
|
||||
|
||||
```
|
||||
```sh
|
||||
./examples/sycl/build.sh
|
||||
```
|
||||
|
||||
@@ -351,7 +353,7 @@ cmake --build build --config Release -j -v
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration, or simply download [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) model as example.
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
@@ -398,11 +400,15 @@ Choose one of following methods to run.
|
||||
|
||||
```sh
|
||||
./examples/sycl/run-llama2.sh 0
|
||||
# OR
|
||||
./examples/sycl/run-llama3.sh 0
|
||||
```
|
||||
- Use multiple devices:
|
||||
|
||||
```sh
|
||||
./examples/sycl/run-llama2.sh
|
||||
# OR
|
||||
./examples/sycl/run-llama3.sh
|
||||
```
|
||||
|
||||
2. Command line
|
||||
@@ -425,13 +431,13 @@ Examples:
|
||||
- Use device 0:
|
||||
|
||||
```sh
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0
|
||||
```
|
||||
|
||||
- Use multiple devices:
|
||||
|
||||
```sh
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer
|
||||
```
|
||||
|
||||
*Notes:*
|
||||
@@ -452,7 +458,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|
||||
1. Install GPU driver
|
||||
|
||||
Intel GPU drivers instructions guide and download page can be found here: [Get intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html).
|
||||
Intel GPU drivers instructions guide and download page can be found here: [Get Intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html).
|
||||
|
||||
2. Install Visual Studio
|
||||
|
||||
@@ -629,7 +635,7 @@ Once it is completed, final results will be in **build/Release/bin**
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration, or simply download [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) model as example.
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
@@ -648,7 +654,7 @@ Similar to the native `sycl-ls`, available SYCL devices can be queried as follow
|
||||
build\bin\llama-ls-sycl-device.exe
|
||||
```
|
||||
|
||||
This command will only display the selected backend that is supported by SYCL. The default backend is level_zero. For example, in a system with 2 *intel GPU* it would look like the following:
|
||||
This command will only display the selected backend that is supported by SYCL. The default backend is level_zero. For example, in a system with 2 *Intel GPU* it would look like the following:
|
||||
```
|
||||
found 2 SYCL devices:
|
||||
| | | |Compute |Max compute|Max work|Max sub| |
|
||||
@@ -658,13 +664,14 @@ found 2 SYCL devices:
|
||||
| 1|[level_zero:gpu:1]| Intel(R) UHD Graphics 770| 1.3| 32| 512| 32| 53651849216|
|
||||
|
||||
```
|
||||
|
||||
#### Choose level-zero devices
|
||||
|
||||
|Chosen Device ID|Setting|
|
||||
|-|-|
|
||||
|0|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action|
|
||||
|0|Default option. You may also want to `set ONEAPI_DEVICE_SELECTOR="level_zero:0"`|
|
||||
|1|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"`|
|
||||
|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`|
|
||||
|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"` or `set ONEAPI_DEVICE_SELECTOR="level_zero:*"`|
|
||||
|
||||
#### Execute
|
||||
|
||||
@@ -673,7 +680,13 @@ Choose one of following methods to run.
|
||||
1. Script
|
||||
|
||||
```
|
||||
examples\sycl\win-run-llama2.bat
|
||||
examples\sycl\win-run-llama-2.bat
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
examples\sycl\win-run-llama-3.bat
|
||||
```
|
||||
|
||||
2. Command line
|
||||
@@ -697,13 +710,13 @@ Examples:
|
||||
- Use device 0:
|
||||
|
||||
```
|
||||
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0
|
||||
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0
|
||||
```
|
||||
|
||||
- Use multiple devices:
|
||||
|
||||
```
|
||||
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer
|
||||
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer
|
||||
```
|
||||
|
||||
|
||||
@@ -714,7 +727,9 @@ Note:
|
||||
```sh
|
||||
detect 1 SYCL GPUs: [0] with top Max compute units:512
|
||||
```
|
||||
|
||||
Or
|
||||
|
||||
```sh
|
||||
use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
```
|
||||
@@ -726,14 +741,17 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|
||||
| Name | Value | Function |
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
|
||||
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path. |
|
||||
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
|
||||
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) |
|
||||
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
|
||||
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
|
||||
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
|
||||
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
|
||||
|
||||
1. FP16 is recommended for better prompt processing performance on quantized models. Performance is equivalent in text generation but set `GGML_SYCL_F16=OFF` if you are experiencing issues with FP16 builds.
|
||||
|
||||
#### Runtime
|
||||
|
||||
| Name | Value | Function |
|
||||
@@ -741,6 +759,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
|
||||
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
|
||||
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
|
||||
|
||||
@@ -750,7 +769,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|
||||
## Q&A
|
||||
|
||||
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
|
||||
- Error: `error while loading shared libraries: libsycl.so: cannot open shared object file: No such file or directory`.
|
||||
|
||||
- Potential cause: Unavailable oneAPI installation or not set ENV variables.
|
||||
- Solution: Install *oneAPI base toolkit* and enable its ENV through: `source /opt/intel/oneapi/setvars.sh`.
|
||||
@@ -779,18 +798,18 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
|
||||
It's same for other projects including llama.cpp SYCL backend.
|
||||
|
||||
- Meet issue: `Native API failed. Native API returns: -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -999 (UNKNOWN PI error)` or `failed to allocate SYCL0 buffer`
|
||||
- `Native API failed. Native API returns: 39 (UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY)`, `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 3503030272 Bytes of memory on device`, or `failed to allocate SYCL0 buffer`
|
||||
|
||||
Device Memory is not enough.
|
||||
You are running out of Device Memory.
|
||||
|
||||
|Reason|Solution|
|
||||
|-|-|
|
||||
|Default Context is too big. It leads to more memory usage.|Set `-c 8192` or smaller value.|
|
||||
|Model is big and require more memory than device's.|Choose smaller quantized model, like Q5 -> Q4;<br>Use more than one devices to load model.|
|
||||
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
|
||||
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
|
||||
|
||||
### **GitHub contribution**:
|
||||
Please add the **[SYCL]** prefix/tag in issues/PRs titles to help the SYCL-team check/address them without delay.
|
||||
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
|
||||
|
||||
## TODO
|
||||
|
||||
- NA
|
||||
- Review ZES_ENABLE_SYSMAN: https://github.com/intel/compute-runtime/blob/master/programmers-guide/SYSMAN.md#support-and-limitations
|
||||
|
||||
@@ -22,6 +22,9 @@ Additionally, there the following images, similar to the above:
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:full-intel`: Same as `full` but compiled with SYCL support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:light-intel`: Same as `light` but compiled with SYCL support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggml-org/llama.cpp:server-intel`: Same as `server` but compiled with SYCL support. (platforms: `linux/amd64`)
|
||||
|
||||
The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now).
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default.
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
|
||||
|
||||
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
||||
|
||||
@@ -74,4 +74,7 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
|
||||
|
||||
# Llama 4 Scout
|
||||
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
|
||||
```
|
||||
|
||||
@@ -32,6 +32,7 @@ else()
|
||||
add_subdirectory(speculative)
|
||||
add_subdirectory(speculative-simple)
|
||||
add_subdirectory(gen-docs)
|
||||
add_subdirectory(training)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
add_subdirectory(convert-llama2c-to-ggml)
|
||||
# these examples use the backends directly and cannot be built with dynamic loading
|
||||
|
||||
@@ -50,8 +50,6 @@ int main(int argc, char ** argv) {
|
||||
const int N = 5; // n-gram size
|
||||
const int G = 15; // max verification n-grams
|
||||
|
||||
const bool dump_kv_cache = params.dump_kv_cache;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
@@ -152,9 +150,6 @@ int main(int argc, char ** argv) {
|
||||
// here we keep adding new n-grams as we go
|
||||
ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);
|
||||
|
||||
// debug
|
||||
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
// sample first token
|
||||
@@ -172,12 +167,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
while (true) {
|
||||
// debug
|
||||
if (dump_kv_cache) {
|
||||
llama_kv_cache_view_update(ctx, &kvc_view);
|
||||
common_kv_cache_dump_view_seqs(kvc_view, 40);
|
||||
}
|
||||
|
||||
// build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
|
||||
//
|
||||
// Example for W = 5, N = 4, G = 2:
|
||||
@@ -473,8 +462,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_sampler_free(smpl);
|
||||
|
||||
llama_kv_cache_view_free(&kvc_view);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
@@ -24,8 +24,6 @@ int main(int argc, char ** argv){
|
||||
// max. number of additional tokens to draft if match is found
|
||||
const int n_draft = params.speculative.n_max;
|
||||
|
||||
const bool dump_kv_cache = params.dump_kv_cache;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
@@ -110,18 +108,9 @@ int main(int argc, char ** argv){
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
|
||||
|
||||
// debug
|
||||
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
while (true) {
|
||||
// debug
|
||||
if (dump_kv_cache) {
|
||||
llama_kv_cache_view_update(ctx, &kvc_view);
|
||||
common_kv_cache_dump_view_seqs(kvc_view, 40);
|
||||
}
|
||||
|
||||
// print current draft sequence
|
||||
LOG_DBG("drafted %s\n", string_from(ctx, draft).c_str());
|
||||
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
# llama.cpp/example/parallel
|
||||
|
||||
Simplified simulation of serving incoming requests in parallel
|
||||
|
||||
## Example
|
||||
|
||||
Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients (`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. The client requests consist of 10 junk questions (`-j 10`) followed by the actual question.
|
||||
|
||||
```bash
|
||||
llama-parallel -m model.gguf -np 8 -ns 128 --top-k 1 -pps --junk 10 -c 16384
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> It's recommended to use base models with this example. Instruction tuned models might not be able to properly follow the custom chat template specified here, so the results might not be as expected.
|
||||
|
||||
@@ -34,11 +34,61 @@ static std::string k_system =
|
||||
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
|
||||
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
|
||||
|
||||
User: Recommend a nice restaurant in the area.
|
||||
Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
|
||||
User: Who is Richard Feynman?
|
||||
Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
|
||||
User:)";
|
||||
User:
|
||||
Recommend a nice restaurant in the area.
|
||||
Assistant:
|
||||
I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
|
||||
User:
|
||||
Who is Richard Feynman?
|
||||
Assistant:
|
||||
Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
|
||||
)";
|
||||
|
||||
static std::vector<std::string> k_questions = {
|
||||
"What is the tallest mountain in the world?",
|
||||
"Who was the first person to win two Nobel Prizes?",
|
||||
"Which country invented paper?",
|
||||
"What organ is primarily responsible for pumping blood throughout the body?",
|
||||
"Which planet is known for its prominent ring system?",
|
||||
"Who directed the movie 'Inception'?",
|
||||
"What is the freezing point of water in Fahrenheit?",
|
||||
"Which animal is known to have the longest lifespan?",
|
||||
"What language has the most native speakers worldwide?",
|
||||
"What is the capital city of Canada?",
|
||||
"Who is credited with inventing the World Wide Web?",
|
||||
"Which metal is liquid at room temperature?",
|
||||
"What is the term for an animal that eats both plants and meat?",
|
||||
"Who painted 'The Starry Night'?",
|
||||
"What gas do humans exhale that plants use for photosynthesis?",
|
||||
"What year did World War II end?",
|
||||
"Which continent has the most countries?",
|
||||
"Who wrote the novel 'Frankenstein'?",
|
||||
"What does DNA stand for?",
|
||||
"What is the main ingredient in traditional Japanese miso soup?"
|
||||
};
|
||||
|
||||
static std::vector<std::string> k_answers = {
|
||||
"The tallest mountain in the world is Mount Everest.",
|
||||
"Marie Curie was the first person to win two Nobel Prizes.",
|
||||
"Paper was invented in China.",
|
||||
"The heart is the organ responsible for pumping blood.",
|
||||
"Saturn is known for its prominent ring system.",
|
||||
"Christopher Nolan directed the movie 'Inception'.",
|
||||
"The freezing point of water in Fahrenheit is 32°F.",
|
||||
"The bowhead whale is known to have the longest lifespan among mammals.",
|
||||
"Mandarin Chinese has the most native speakers in the world.",
|
||||
"The capital city of Canada is Ottawa.",
|
||||
"Tim Berners-Lee is credited with inventing the World Wide Web.",
|
||||
"Mercury is the metal that is liquid at room temperature.",
|
||||
"An animal that eats both plants and meat is called an omnivore.",
|
||||
"'The Starry Night' was painted by Vincent van Gogh.",
|
||||
"Humans exhale carbon dioxide, which plants use in photosynthesis.",
|
||||
"World War II ended in 1945.",
|
||||
"Africa is the continent with the most countries.",
|
||||
"The novel 'Frankenstein' was written by Mary Shelley.",
|
||||
"DNA stands for Deoxyribonucleic Acid.",
|
||||
"The main ingredient in traditional Japanese miso soup is fermented soybean paste."
|
||||
};
|
||||
|
||||
static std::vector<std::string> k_prompts = {
|
||||
"What is the meaning of life?",
|
||||
@@ -49,7 +99,7 @@ static std::vector<std::string> k_prompts = {
|
||||
"What is the best way to learn a new language?",
|
||||
"How to get a job at Google?",
|
||||
"If you could have any superpower, what would it be?",
|
||||
"I want to learn how to play the piano.",
|
||||
"I want to learn how to play the piano. What would be the best way to do it?",
|
||||
};
|
||||
|
||||
struct client {
|
||||
@@ -68,6 +118,7 @@ struct client {
|
||||
int64_t t_start_prompt;
|
||||
int64_t t_start_gen;
|
||||
|
||||
int32_t n_past = 0;
|
||||
int32_t n_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t i_batch = -1;
|
||||
@@ -107,6 +158,7 @@ int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.n_predict = 128;
|
||||
params.n_junk = 0;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
|
||||
return 1;
|
||||
@@ -126,7 +178,11 @@ int main(int argc, char ** argv) {
|
||||
// insert new requests as soon as the previous one is done
|
||||
const bool cont_batching = params.cont_batching;
|
||||
|
||||
const bool dump_kv_cache = params.dump_kv_cache;
|
||||
// is the system prompt shared in the cache
|
||||
const bool is_sp_shared = params.is_pp_shared;
|
||||
|
||||
// extra text to insert in each client's prompt in order to make it larger
|
||||
const int32_t n_junk = params.n_junk;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
@@ -169,6 +225,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
||||
tokens_system = common_tokenize(ctx, k_system, true);
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
||||
@@ -182,15 +239,13 @@ int main(int argc, char ** argv) {
|
||||
int32_t n_total_gen = 0;
|
||||
int32_t n_cache_miss = 0;
|
||||
|
||||
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
LOG_INF("%s: Simulating parallel requests from clients:\n", __func__);
|
||||
LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
|
||||
LOG_INF("\n");
|
||||
|
||||
{
|
||||
if (is_sp_shared) {
|
||||
LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
|
||||
|
||||
for (int32_t i = 0; i < n_tokens_system; ++i) {
|
||||
@@ -213,11 +268,6 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("Processing requests ...\n\n");
|
||||
|
||||
while (true) {
|
||||
if (dump_kv_cache) {
|
||||
llama_kv_cache_view_update(ctx, &kvc_view);
|
||||
common_kv_cache_dump_view_seqs(kvc_view, 40);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
// decode any currently ongoing sequences
|
||||
@@ -228,7 +278,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
||||
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
|
||||
common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true);
|
||||
|
||||
client.n_decoded += 1;
|
||||
}
|
||||
@@ -254,9 +304,23 @@ int main(int argc, char ** argv) {
|
||||
client.t_start_gen = 0;
|
||||
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
// construct the prompt:
|
||||
// [system prompt] + [junk] + [user prompt]
|
||||
client.n_past = 0;
|
||||
client.prompt = "";
|
||||
if (is_sp_shared) {
|
||||
client.n_past = n_tokens_system;
|
||||
} else {
|
||||
client.prompt += k_system;
|
||||
}
|
||||
for (int i = 0; i < n_junk; ++i) {
|
||||
const int r = rand() % k_questions.size();
|
||||
client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n";
|
||||
}
|
||||
client.prompt += "User:\n" + client.input + "\nAssistant:\n";
|
||||
|
||||
common_sampler_reset(client.smpl);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
@@ -264,7 +328,7 @@ int main(int argc, char ** argv) {
|
||||
tokens_prompt = common_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
|
||||
common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
@@ -363,10 +427,9 @@ int main(int argc, char ** argv) {
|
||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (client.n_decoded > 2 &&
|
||||
(llama_vocab_is_eog(vocab, id) ||
|
||||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos)) {
|
||||
(llama_vocab_is_eog(vocab, id) ||
|
||||
(params.n_predict > 0 && client.n_decoded >= params.n_predict) ||
|
||||
client.response.find("User:") != std::string::npos)) {
|
||||
// basic reverse prompt
|
||||
const size_t pos = client.response.find("User:");
|
||||
if (pos != std::string::npos) {
|
||||
|
||||
@@ -84,13 +84,13 @@ int main(int argc, char ** argv) {
|
||||
model_params.n_gpu_layers = ngl;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (model == NULL) {
|
||||
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
// tokenize the prompt
|
||||
|
||||
// find the number of tokens in the prompt
|
||||
|
||||
@@ -12,16 +12,16 @@ source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
|
||||
NGL=33
|
||||
CONEXT=4096
|
||||
NGL=99
|
||||
CONTEXT=4096
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use signle GPU only
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
|
||||
fi
|
||||
|
||||
28
examples/sycl/run-llama3.sh
Executable file
28
examples/sycl/run-llama3.sh
Executable file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT license
|
||||
# Copyright (C) 2025 Intel Corporation
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# If you want more control, DPC++ Allows selecting a specific device through the
|
||||
# following environment variable
|
||||
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#export GGML_SYCL_DEBUG=1
|
||||
|
||||
#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.
|
||||
|
||||
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
|
||||
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
|
||||
CONTEXT=4096
|
||||
|
||||
if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "Using $GGML_SYCL_DEVICE as the main GPU"
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
|
||||
fi
|
||||
@@ -6,4 +6,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0
|
||||
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0
|
||||
|
||||
9
examples/sycl/win-run-llama3.bat
Normal file
9
examples/sycl/win-run-llama3.bat
Normal file
@@ -0,0 +1,9 @@
|
||||
:: MIT license
|
||||
:: Copyright (C) 2024 Intel Corporation
|
||||
:: SPDX-License-Identifier: MIT
|
||||
|
||||
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
|
||||
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
|
||||
|
||||
|
||||
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
|
||||
5
examples/training/CMakeLists.txt
Normal file
5
examples/training/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-finetune)
|
||||
add_executable(${TARGET} finetune.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
17
examples/training/README.md
Normal file
17
examples/training/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# llama.cpp/examples/training
|
||||
|
||||
This directory contains examples related to language model training using llama.cpp/GGML.
|
||||
So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP.
|
||||
Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
|
||||
**For CPU training, compile llama.cpp without any additional backends such as CUDA.**
|
||||
**For CUDA training, use the maximum number of GPU layers.**
|
||||
|
||||
Proof of concept:
|
||||
|
||||
``` sh
|
||||
export model_name=llama_3.2-1b && export quantization=f32
|
||||
./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
|
||||
./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
|
||||
```
|
||||
|
||||
The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
|
||||
96
examples/training/finetune.cpp
Normal file
96
examples/training/finetune.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <vector>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.escape = false;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.use_mmap) {
|
||||
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
|
||||
params.use_mmap = false;
|
||||
}
|
||||
if (params.cache_type_k != GGML_TYPE_F32) {
|
||||
LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
|
||||
params.cache_type_k = GGML_TYPE_F32;
|
||||
}
|
||||
if (params.cache_type_v != GGML_TYPE_F32) {
|
||||
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
|
||||
params.cache_type_v = GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
llama_model_ptr & model = llama_init.model;
|
||||
llama_context_ptr & ctx = llama_init.context;
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
constexpr float val_split = 0.05f;
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
|
||||
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
|
||||
|
||||
struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
|
||||
optimizer_params.adamw.alpha = 1e-7f; // learning rate
|
||||
|
||||
struct llama_opt_params lopt_params {
|
||||
/*n_ctx_train =*/ 0,
|
||||
/*param_filter =*/ llama_opt_param_filter_all,
|
||||
/*param_filter_ud =*/ nullptr,
|
||||
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
|
||||
/*get_opt_pars_ud =*/ &optimizer_params,
|
||||
};
|
||||
llama_opt_init(ctx.get(), model.get(), lopt_params);
|
||||
|
||||
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
|
||||
|
||||
ggml_opt_result_t result_train = ggml_opt_result_init();
|
||||
ggml_opt_result_t result_eval = ggml_opt_result_init();
|
||||
|
||||
for (int epoch = 0; epoch < 2; ++epoch) {
|
||||
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
|
||||
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
ggml_opt_result_reset(result_train);
|
||||
ggml_opt_result_reset(result_eval);
|
||||
}
|
||||
ggml_opt_result_free(result_train);
|
||||
ggml_opt_result_free(result_eval);
|
||||
|
||||
llama_model_save_to_file(model.get(), "finetuned-model.gguf");
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
|
||||
option(GGML_SYCL "ggml: use SYCL" OFF)
|
||||
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
||||
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
||||
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
|
||||
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
||||
"ggml: sycl target device")
|
||||
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
||||
|
||||
@@ -37,13 +37,16 @@ extern "C" {
|
||||
// ====== Dataset ======
|
||||
|
||||
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
|
||||
int64_t ne_datapoint, // number of elements per datapoint
|
||||
int64_t ne_label, // number of elements per label
|
||||
int64_t ndata, // total number of datapoints/labels
|
||||
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
||||
enum ggml_type type_data, // the type for the internal data tensor
|
||||
enum ggml_type type_label, // the type for the internal labels tensor
|
||||
int64_t ne_datapoint, // number of elements per datapoint
|
||||
int64_t ne_label, // number of elements per label
|
||||
int64_t ndata, // total number of datapoints/labels
|
||||
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
||||
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
|
||||
|
||||
// get underlying tensors that store the data
|
||||
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
|
||||
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
|
||||
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
|
||||
|
||||
@@ -56,13 +59,19 @@ extern "C" {
|
||||
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
|
||||
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
|
||||
int64_t ibatch);
|
||||
GGML_API void ggml_opt_dataset_get_batch_host(
|
||||
ggml_opt_dataset_t dataset,
|
||||
void * data_batch,
|
||||
size_t nb_data_batch,
|
||||
void * labels_batch,
|
||||
int64_t ibatch);
|
||||
|
||||
// ====== Model / Context ======
|
||||
|
||||
enum ggml_opt_build_type {
|
||||
GGML_OPT_BUILD_TYPE_FORWARD,
|
||||
GGML_OPT_BUILD_TYPE_GRAD,
|
||||
GGML_OPT_BUILD_TYPE_OPT,
|
||||
GGML_OPT_BUILD_TYPE_FORWARD = 10,
|
||||
GGML_OPT_BUILD_TYPE_GRAD = 20,
|
||||
GGML_OPT_BUILD_TYPE_OPT = 30,
|
||||
};
|
||||
|
||||
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
||||
@@ -81,20 +90,22 @@ extern "C" {
|
||||
// userdata can be used to pass arbitrary data
|
||||
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
|
||||
|
||||
// returns the default optimizer params (constant)
|
||||
// returns the default optimizer params (constant, hard-coded values)
|
||||
// userdata is not used
|
||||
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
|
||||
|
||||
// casts userdata to ggml_opt_optimizer_params and returns it
|
||||
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
|
||||
|
||||
// parameters for initializing a new optimization context
|
||||
struct ggml_opt_params {
|
||||
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
|
||||
|
||||
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
|
||||
|
||||
// the forward graph is defined by inputs and outputs
|
||||
// those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
|
||||
struct ggml_tensor * inputs;
|
||||
struct ggml_tensor * outputs;
|
||||
// by default the forward graph needs to be reconstructed for each eval
|
||||
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
|
||||
struct ggml_context * ctx_compute;
|
||||
struct ggml_tensor * inputs;
|
||||
struct ggml_tensor * outputs;
|
||||
|
||||
enum ggml_opt_loss_type loss_type;
|
||||
enum ggml_opt_build_type build_type;
|
||||
@@ -107,12 +118,9 @@ extern "C" {
|
||||
|
||||
// get parameters for an optimization context with defaults set where possible
|
||||
// parameters for which no sensible defaults exist are supplied as arguments to this function
|
||||
GGML_API ggml_opt_params ggml_opt_default_params(
|
||||
ggml_backend_sched_t backend_sched,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs,
|
||||
enum ggml_opt_loss_type loss_type);
|
||||
GGML_API struct ggml_opt_params ggml_opt_default_params(
|
||||
ggml_backend_sched_t backend_sched,
|
||||
enum ggml_opt_loss_type loss_type);
|
||||
|
||||
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
|
||||
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
|
||||
@@ -120,7 +128,10 @@ extern "C" {
|
||||
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
||||
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
||||
|
||||
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
||||
|
||||
// get underlying tensors that store data
|
||||
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
||||
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
||||
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
|
||||
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
|
||||
@@ -128,11 +139,12 @@ extern "C" {
|
||||
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
|
||||
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
|
||||
|
||||
// get the gradient accumulator for a node from the forward graph
|
||||
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
|
||||
|
||||
// ====== Optimization Result ======
|
||||
|
||||
GGML_API ggml_opt_result_t ggml_opt_result_init();
|
||||
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
|
||||
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
|
||||
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
|
||||
|
||||
@@ -144,11 +156,20 @@ extern "C" {
|
||||
|
||||
// ====== Computation ======
|
||||
|
||||
// do forward pass, increment result if not NULL
|
||||
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||
// if not using static graphs, this function must be called prior to ggml_opt_alloc
|
||||
GGML_API void ggml_opt_prepare_alloc(
|
||||
ggml_opt_context_t opt_ctx,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs);
|
||||
|
||||
// do forward pass, increment result if not NULL, do backward pass
|
||||
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||
// allocate the next graph for evaluation, either forward or forward + backward
|
||||
// must be called exactly once prior to calling ggml_opt_eval
|
||||
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
|
||||
|
||||
// do forward pass, increment result if not NULL, do backward pass if allocated
|
||||
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||
|
||||
// ############################################################################
|
||||
// ## The high-level functions start here. They do not depend on any private ##
|
||||
@@ -200,9 +221,9 @@ extern "C" {
|
||||
// fit model defined by inputs and outputs to dataset
|
||||
GGML_API void ggml_opt_fit(
|
||||
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
|
||||
ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
||||
ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
||||
ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
||||
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
||||
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
||||
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
||||
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
||||
enum ggml_opt_loss_type loss_type, // loss to minimize
|
||||
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
||||
|
||||
@@ -768,7 +768,7 @@ extern "C" {
|
||||
// Tensor flags
|
||||
GGML_API void ggml_set_input(struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_output(struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_param(struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
|
||||
|
||||
//
|
||||
@@ -938,7 +938,7 @@ extern "C" {
|
||||
GGML_API struct ggml_tensor * ggml_repeat_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
|
||||
|
||||
// concat a and b along dim
|
||||
// used in stable-diffusion
|
||||
@@ -2049,15 +2049,14 @@ extern "C" {
|
||||
|
||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(
|
||||
struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
|
||||
struct ggml_context * ctx_compute, // context for gradient computation
|
||||
struct ggml_cgraph * cgraph,
|
||||
bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
|
||||
struct ggml_context * ctx, // context for gradient computation
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor ** grad_accs);
|
||||
|
||||
// graph allocation in a context
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
|
||||
GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
|
||||
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
|
||||
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
|
||||
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
|
||||
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
|
||||
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
|
||||
|
||||
@@ -1111,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
|
||||
const int node_backend_id = tensor_backend_id(node);
|
||||
|
||||
assert(node_backend_id != -1); // all nodes should be assigned by now
|
||||
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
|
||||
|
||||
// check if we should start a new split based on the sources of the current node
|
||||
bool need_new_split = false;
|
||||
|
||||
@@ -65,6 +65,7 @@
|
||||
#include <aclnnop/aclnn_eq_tensor.h>
|
||||
#include <aclnnop/aclnn_gt_scalar.h>
|
||||
#include <aclnnop/aclnn_pow.h>
|
||||
#include <aclnnop/aclnn_grouped_matmul_v2.h>
|
||||
#include <float.h>
|
||||
|
||||
#include <cmath>
|
||||
@@ -2587,3 +2588,149 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs expert-specific matrix multiplication (MoE) with
|
||||
* floating-point precision using the CANN backend.
|
||||
*
|
||||
* This function executes a matrix multiplication operation tailored for
|
||||
* Mixture of Experts (MoE) models, where the input tensor is multiplied
|
||||
* with expert-specific weight matrices. It uses the CANN backend for
|
||||
* efficient computation and stores the result in the destination tensor `dst`.
|
||||
* The operation may leverage identity-based optimizations or routing masks
|
||||
* as part of sparse expert selection.
|
||||
*
|
||||
* @param ctx The context for executing CANN backend operations.
|
||||
* @param dst The destination tensor where the MoE multiplication result
|
||||
* will be stored.
|
||||
*
|
||||
* @note This function assumes floating-point data types and is designed for
|
||||
* MoE architectures, possibly involving sparse expert routing.
|
||||
*/
|
||||
static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
//dst [M, K, N, 1]
|
||||
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
|
||||
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
|
||||
ggml_tensor * ids = dst->src[2]; //ids [K, N]
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// copy index from npu to cpu
|
||||
int64_t n_as = ne02; // A
|
||||
int64_t n_ids = ids->ne[0]; // K
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
|
||||
|
||||
char * src0_original = (char *) src0->data;
|
||||
char * src1_original = (char *) src1->data;
|
||||
char * dst_original = (char *) dst->data;
|
||||
size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
|
||||
|
||||
// src0 is F16, src1 is F32, dst is F32
|
||||
ggml_cann_pool_alloc src0_cast_allocator;
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
|
||||
void* src0_cast_buf = src0_cast_allocator.get();
|
||||
|
||||
size_t cast_nb[GGML_MAX_DIMS];
|
||||
cast_nb[0] = sizeof(float_t);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
|
||||
aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
|
||||
ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
|
||||
ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
|
||||
|
||||
src0_original = (char *) src0_cast_buf;
|
||||
memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
|
||||
}
|
||||
|
||||
std::vector<aclTensor*> src0_tensor_vec;
|
||||
std::vector<aclTensor*> src1_tensor_vec;
|
||||
std::vector<aclTensor*> dst_tensor_vec;
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
// src0_row [M, D] -> weight && permute
|
||||
int64_t src0_ne[2] = {ne01, ne00};
|
||||
size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
|
||||
// src1_row [D, 1] -> input
|
||||
int64_t src1_ne[2] = {ne10, 1};
|
||||
size_t src1_nb[2] = {nb10, nb11};
|
||||
// dst_row [M, 1] -> out
|
||||
int64_t dst_ne[2] = {ne0, 1};
|
||||
size_t dst_nb[2] = {nb0, nb1};
|
||||
|
||||
// expert index
|
||||
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
// If B = 1 (broadcast), always use 0; otherwise, use id.
|
||||
int64_t i11 = (ne11 == 1 ? 0 : id);
|
||||
int64_t i12 = iid1;
|
||||
|
||||
int64_t i1 = id;
|
||||
int64_t i2 = i12;
|
||||
|
||||
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
|
||||
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
|
||||
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
src0_ne, src0_nb, 2);
|
||||
aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
src1_ne, src1_nb, 2);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
dst_ne, dst_nb, 2);
|
||||
|
||||
src0_tensor_vec.push_back(acl_src0);
|
||||
src1_tensor_vec.push_back(acl_src1);
|
||||
dst_tensor_vec.push_back(acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
// GroupedMatmulV2 required tensor_list.size < 128
|
||||
size_t GROUP_SIZE = 128;
|
||||
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
|
||||
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
|
||||
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
|
||||
|
||||
// split and call GroupedMatmulV2
|
||||
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
|
||||
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
|
||||
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
|
||||
|
||||
aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
|
||||
aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
|
||||
aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
|
||||
|
||||
ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
const enum ggml_type type = dst->src[0]->type;
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
ggml_cann_mul_mat_id_fp(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for mul_mat_id");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -978,6 +978,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
|
||||
*
|
||||
* @details This function implements a MoE-style batched matrix multiplication, where each input token
|
||||
* is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix
|
||||
* in the source tensor `src0`. The routing indices are provided via the `ids` tensor.
|
||||
*
|
||||
* For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`,
|
||||
* performs the matrix multiplication with the selected expert's weight submatrix (from `src0`),
|
||||
* and stores the results in `dst`. This operation is optimized and executed on the CANN backend.
|
||||
*
|
||||
* Dimensions:
|
||||
* - src0: [D, M, A, 1], where A is the number of experts
|
||||
* - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample
|
||||
* - ids : [K, N], where K is the number of experts each token is routed to
|
||||
* - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication
|
||||
*
|
||||
* The function handles two main modes:
|
||||
* - If `ne12 == 1`, a simpler per-token loop is used.
|
||||
* - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the expert-weighted token outputs are stored.
|
||||
* Expected to be of shape [M, K, N, 1].
|
||||
*/
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a element-wise operation to two input tensors using the CANN
|
||||
* backend.
|
||||
|
||||
@@ -1672,7 +1672,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
ggml_cann_mul_mat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return false;
|
||||
ggml_cann_mul_mat_id(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SCALE:
|
||||
ggml_cann_scale(ctx, dst);
|
||||
break;
|
||||
@@ -2030,7 +2031,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
}
|
||||
}
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return false;
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
// embedding
|
||||
case GGML_OP_GET_ROWS: {
|
||||
switch (op->src[0]->type) {
|
||||
|
||||
@@ -385,9 +385,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.5.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.6.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
@@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
|
||||
|
||||
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
|
||||
@@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
|
||||
|
||||
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
|
||||
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
|
||||
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
|
||||
if (NOT DOTPROD_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
|
||||
endif()
|
||||
|
||||
if (NOT I8MM_ENABLED MATCHES -1)
|
||||
@@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
|
||||
if (NOT SME_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
|
||||
set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
|
||||
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
endif()
|
||||
|
||||
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
|
||||
|
||||
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
assert((nrc == 2) || (nrc == 1));
|
||||
#else
|
||||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
const block_q6_K * GGML_RESTRICT x0 = x;
|
||||
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
||||
const block_q8_K * GGML_RESTRICT y0 = y;
|
||||
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
|
||||
|
||||
float32x4_t vfsum = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
|
||||
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
|
||||
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
|
||||
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
|
||||
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
|
||||
const int8_t * GGML_RESTRICT qy0 = y0->qs;
|
||||
const int8_t * GGML_RESTRICT qy1 = y1->qs;
|
||||
|
||||
const uint8x16_t mone = vdupq_n_u8(0x30);
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
int32x4_t visum = vdupq_n_s32(0);
|
||||
|
||||
// process 8 blocks per iteration, totally 16 blocks
|
||||
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
|
||||
int8x16_t vx0[8], vx1[8];
|
||||
|
||||
// de-quantize vx0[8]
|
||||
{
|
||||
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
|
||||
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
|
||||
|
||||
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||
|
||||
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||
}
|
||||
|
||||
// de-quantize vx1[8]
|
||||
{
|
||||
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
|
||||
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
|
||||
|
||||
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||
|
||||
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||
}
|
||||
|
||||
// process 16 elements (one block with same scale) per iteration
|
||||
// - vx = concat(ql, qh) - 32
|
||||
// - r1,r2,r3,r4 = smmla(vx, vy)
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const int blk = j * 8 + k;
|
||||
|
||||
const int8x16_t vy0 = vld1q_s8(qy0);
|
||||
const int8x16_t vy1 = vld1q_s8(qy1);
|
||||
qy0 += 16;
|
||||
qy1 += 16;
|
||||
|
||||
const int32x4_t block_scale = {
|
||||
x0->scales[blk],
|
||||
x0->scales[blk],
|
||||
x1->scales[blk],
|
||||
x1->scales[blk],
|
||||
};
|
||||
|
||||
// calculate four results at once with outer product
|
||||
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||
int32x4_t vr = vdupq_n_s32(0);
|
||||
vr = vmmlaq_s32(vr, vx_l, vy_l);
|
||||
vr = vmmlaq_s32(vr, vx_h, vy_h);
|
||||
|
||||
// apply block scale, will NOT overflow
|
||||
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
|
||||
visum = vmlaq_s32(visum, vr, block_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// adjust bias, apply superblock scale
|
||||
{
|
||||
int32_t bias[4];
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
||||
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
||||
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
||||
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
||||
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
||||
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
||||
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
||||
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
||||
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
||||
const svint64_t zero = svdup_n_s64(0);
|
||||
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
||||
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
||||
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
||||
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
||||
#else
|
||||
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
||||
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
||||
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
||||
|
||||
int8x16_t scales_s8 = vld1q_s8(x0->scales);
|
||||
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||
scales_s8 = vld1q_s8(x1->scales);
|
||||
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||
|
||||
int32x4_t prod;
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||
bias[0] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||
bias[1] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[2] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[3] = vaddvq_s32(prod);
|
||||
|
||||
#endif
|
||||
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
||||
|
||||
const float32x4_t superblock_scale = {
|
||||
GGML_FP16_TO_FP32(x0->d) * y0->d,
|
||||
GGML_FP16_TO_FP32(x0->d) * y1->d,
|
||||
GGML_FP16_TO_FP32(x1->d) * y0->d,
|
||||
GGML_FP16_TO_FP32(x1->d) * y1->d,
|
||||
};
|
||||
|
||||
visum = vsubq_s32(visum, vibias);
|
||||
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// vfsum = ABCD -> ACBD
|
||||
// AC -> s, BD -> (s+bs)
|
||||
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
|
||||
vst1_f32(s, vget_low_f32 (vfsum));
|
||||
vst1_f32(s + bs, vget_high_f32(vfsum));
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
float sum = 0;
|
||||
|
||||
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_q6_K,
|
||||
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
.nrows = 2,
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
},
|
||||
[GGML_TYPE_IQ2_XXS] = {
|
||||
.from_float = NULL,
|
||||
|
||||
@@ -4,16 +4,22 @@
|
||||
|
||||
// KleidiAI micro-kernels
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
||||
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
||||
|
||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
|
||||
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
||||
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
||||
|
||||
#include "kai_common.h"
|
||||
|
||||
#include "kernels.h"
|
||||
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
{
|
||||
/* SME GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_F16,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__APPLE__)
|
||||
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#else
|
||||
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#endif
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
||||
ggml_kleidiai_kernels * kernel = nullptr;
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||
gemm_gemv_kernels[i].op_type == tensor->type) {
|
||||
kernel = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <variant>
|
||||
#include "ggml.h"
|
||||
|
||||
enum cpu_feature {
|
||||
CPU_FEATURE_NONE = 0,
|
||||
CPU_FEATURE_DOTPROD = 1,
|
||||
@@ -26,26 +30,53 @@ struct kernel_info {
|
||||
size_t (*get_nr)(void);
|
||||
size_t (*get_kr)(void);
|
||||
size_t (*get_sr)(void);
|
||||
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
|
||||
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
|
||||
std::variant<
|
||||
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
||||
std::function<size_t(size_t m_idx, size_t k)>
|
||||
> get_lhs_offset;
|
||||
std::variant<
|
||||
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
||||
std::function<size_t(size_t n_idx, size_t k)>
|
||||
> get_rhs_packed_offset;
|
||||
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
||||
size_t (*get_dst_size)(size_t m, size_t n);
|
||||
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
||||
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
|
||||
std::variant<
|
||||
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
||||
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
|
||||
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
|
||||
size_t dst_stride_col, float clamp_min, float clamp_max)>
|
||||
> run_kernel;
|
||||
};
|
||||
|
||||
struct lhs_packing_info {
|
||||
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
||||
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed);
|
||||
std::variant<
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
|
||||
> get_packed_offset;
|
||||
std::variant<
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
||||
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
|
||||
> packed_size;
|
||||
std::variant<
|
||||
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed)>,
|
||||
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
|
||||
void* lhs_packed)>
|
||||
> pack_func;
|
||||
};
|
||||
|
||||
struct rhs_packing_info {
|
||||
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
|
||||
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
|
||||
std::variant<
|
||||
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
||||
std::function<size_t(size_t n, size_t k)>
|
||||
> packed_size;
|
||||
std::variant<
|
||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
||||
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
||||
> pack_func;
|
||||
};
|
||||
|
||||
struct ggml_kleidiai_kernels {
|
||||
@@ -55,6 +86,10 @@ struct ggml_kleidiai_kernels {
|
||||
rhs_packing_info rhs_info;
|
||||
|
||||
cpu_feature required_cpu;
|
||||
ggml_type lhs_type;
|
||||
ggml_type rhs_type;
|
||||
ggml_type op_type;
|
||||
};
|
||||
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
//
|
||||
#include <arm_neon.h>
|
||||
#include <assert.h>
|
||||
#include <atomic>
|
||||
#include <cfloat>
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
#if defined(__linux__)
|
||||
@@ -34,8 +36,9 @@
|
||||
#include "ggml-common.h"
|
||||
|
||||
struct ggml_kleidiai_context {
|
||||
cpu_feature features;
|
||||
ggml_kleidiai_kernels * kernels;
|
||||
} static ctx = { NULL };
|
||||
} static ctx = { CPU_FEATURE_NONE, NULL };
|
||||
|
||||
static void init_kleidiai_context(void) {
|
||||
|
||||
@@ -47,18 +50,18 @@ static void init_kleidiai_context(void) {
|
||||
const char *env_var = getenv("GGML_KLEIDIAI_SME");
|
||||
int sme_enabled = 0;
|
||||
|
||||
cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
|
||||
if (env_var) {
|
||||
sme_enabled = atoi(env_var);
|
||||
}
|
||||
|
||||
if (sme_enabled != 0) {
|
||||
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
}
|
||||
ctx.kernels = ggml_kleidiai_select_kernels(features);
|
||||
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||
}
|
||||
ggml_critical_section_end();
|
||||
}
|
||||
@@ -68,95 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
||||
return tensor->ne[dim];
|
||||
}
|
||||
|
||||
template<typename Ret, typename Variant, typename... Args>
|
||||
static Ret variant_call(const Variant & var, Args&&... args) {
|
||||
return std::visit([&](auto&& func) -> Ret {
|
||||
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
|
||||
return func(std::forward<Args>(args)...);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid function type in variant_call");
|
||||
}
|
||||
}, var);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
|
||||
static size_t round_down(size_t x, size_t y) {
|
||||
return y == 0 ? x : x - (x % y);
|
||||
}
|
||||
|
||||
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
|
||||
size_t src_stride = rhs_stride / sizeof(uint16_t);
|
||||
size_t dst_stride = n;
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
|
||||
uint16_t v = *(src + k_idx + n_idx * src_stride);
|
||||
*(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||
GGML_ASSERT(kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
|
||||
size_t k = op->src[0]->ne[0];
|
||||
size_t n = op->src[0]->ne[1];
|
||||
size_t m = op->src[1]->ne[1];
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
|
||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||
k * n * sizeof(float) + n * sizeof(float);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
||||
if (dst->op == GGML_OP_MUL_MAT) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
return compute_forward_q4_0(params, dst);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
return compute_forward_kv_cache(params, dst);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
||||
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
||||
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
const size_t k = ne00;
|
||||
const size_t m = ne11;
|
||||
const size_t n = ne01;
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
const size_t n_step = kernel->get_n_step();
|
||||
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||
const size_t n_start = ith * num_n_per_thread;
|
||||
const int nth = params->nth;
|
||||
const int ith = params->ith;
|
||||
|
||||
size_t n_to_process = num_n_per_thread;
|
||||
if ((n_start + n_to_process) > n) {
|
||||
n_to_process = n - n_start;
|
||||
const int64_t lhs_batch_size0 = ne12;
|
||||
const int64_t rhs_batch_size0 = ne02;
|
||||
const int64_t batch_size = rhs_batch_size0;
|
||||
|
||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||
|
||||
const int64_t m = ne11 * r;
|
||||
const int64_t n = ne01;
|
||||
const int64_t k = ne00;
|
||||
|
||||
const size_t lhs_stride = src1->nb[1];
|
||||
const size_t rhs_stride = src0->nb[1];
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
|
||||
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
|
||||
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
|
||||
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
|
||||
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
|
||||
|
||||
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
|
||||
const size_t kxn_size = k * n * sizeof(float);
|
||||
const size_t bias_size = n * sizeof(float);
|
||||
|
||||
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
||||
GGML_ASSERT(wsize_required <= params->wsize);
|
||||
|
||||
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
|
||||
uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
|
||||
uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
|
||||
uint8_t * bias = rhs_kxn + kxn_size;
|
||||
|
||||
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
|
||||
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
|
||||
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
|
||||
|
||||
// LHS packing
|
||||
{
|
||||
const int64_t m_roundup_mr = kai_roundup(m, mr);
|
||||
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
|
||||
|
||||
if (ith < num_threads) {
|
||||
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
|
||||
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
|
||||
|
||||
const int64_t m_start = ith * num_m_per_thread0;
|
||||
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
|
||||
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
|
||||
|
||||
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
|
||||
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
|
||||
|
||||
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
||||
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||
// RHS packing
|
||||
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
|
||||
// First thread to reach this point handles RHS packing
|
||||
memset(bias, 0, n * sizeof(float));
|
||||
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
|
||||
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
// Calculate number of columns to be processed per thread
|
||||
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if(m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
first_to_arrive.clear(std::memory_order_release);
|
||||
|
||||
kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
|
||||
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
return true;
|
||||
// Perform the matmul
|
||||
{
|
||||
const int64_t m_to_process = m;
|
||||
const int64_t m_start = 0;
|
||||
|
||||
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
|
||||
const int64_t num_threads = KAI_MIN(n / n_step, nth);
|
||||
|
||||
if (ith < num_threads) {
|
||||
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
|
||||
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
|
||||
|
||||
const int64_t n_start = ith * num_n_per_thread0;
|
||||
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
|
||||
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
|
||||
|
||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
|
||||
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
if (batch_idx != batch_size - 1) {
|
||||
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
|
||||
// the work data buffer (params->wdata) is used as temporary storage which means that only
|
||||
// a single batch can be processed at any given time. No barrier is needed for the last
|
||||
// batch since GGML inserts a barrier between the execution of every operator.
|
||||
ggml_barrier(params->threadpool);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = &kernels->lhs_info;
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const size_t k = ne00;
|
||||
const size_t m = ne11;
|
||||
const size_t n = ne01;
|
||||
|
||||
size_t mr = kernel->get_mr();
|
||||
size_t kr = kernel->get_kr();
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
||||
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
||||
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
||||
|
||||
const size_t n_step = kernel->get_n_step();
|
||||
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
||||
const size_t n_start = ith * num_n_per_thread;
|
||||
|
||||
size_t n_to_process = num_n_per_thread;
|
||||
if ((n_start + n_to_process) > n) {
|
||||
n_to_process = n - n_start;
|
||||
}
|
||||
|
||||
// Calculate number of columns to be processed per thread
|
||||
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
||||
const size_t m_start = ith * num_m_per_thread;
|
||||
size_t m_to_process = num_m_per_thread;
|
||||
if ((m_start + m_to_process) > m) {
|
||||
m_to_process = m - m_start;
|
||||
}
|
||||
|
||||
if (m_start < m) {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -169,13 +352,13 @@ public:
|
||||
size_t sr = ctx.kernels->gemm.get_sr();
|
||||
|
||||
#ifndef NDEBUG
|
||||
const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
|
||||
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
||||
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
||||
#endif
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms);
|
||||
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -189,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
||||
}
|
||||
} // namespace ggml::cpu::kleidiai
|
||||
|
||||
GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
@@ -238,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
||||
namespace ggml::cpu::kleidiai {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
if ( op->op == GGML_OP_MUL_MAT &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
|
||||
) {
|
||||
if (op->op == GGML_OP_MUL_MAT &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
@@ -260,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
}
|
||||
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
|
||||
op->src[0]->op == GGML_OP_VIEW &&
|
||||
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
|
||||
op->src[1]->ne[1] > 1) {
|
||||
if ((op->src[0]->nb[0] != 2) ||
|
||||
(op->src[1]->nb[0] != 4) ||
|
||||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
||||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -1,47 +1,61 @@
|
||||
#include "acc.cuh"
|
||||
|
||||
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, int offset) {
|
||||
const int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
|
||||
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
int src1_idx = i - offset;
|
||||
int oz = src1_idx / nb2;
|
||||
int oy = (src1_idx - (oz * nb2)) / nb1;
|
||||
int ox = src1_idx % nb1;
|
||||
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
||||
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
||||
} else {
|
||||
dst[i] = x[i];
|
||||
|
||||
int64_t src1_idx = i - offset;
|
||||
|
||||
int64_t tmp = src1_idx;
|
||||
const int64_t i13 = tmp / s13;
|
||||
tmp -= i13 * s13;
|
||||
const int64_t i12 = tmp / s12;
|
||||
tmp -= i12 * s12;
|
||||
const int64_t i11 = tmp / s11;
|
||||
tmp -= i11 * s11;
|
||||
const int64_t i10 = tmp;
|
||||
|
||||
float val = x[i];
|
||||
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
|
||||
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
|
||||
}
|
||||
dst[i] = val;
|
||||
}
|
||||
|
||||
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, const int offset, cudaStream_t stream) {
|
||||
int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
|
||||
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
|
||||
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
|
||||
const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
|
||||
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
|
||||
|
||||
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
|
||||
const int64_t s1 = dst->op_params[0] / sizeof(float);
|
||||
const int64_t s2 = dst->op_params[1] / sizeof(float);
|
||||
const int64_t s3 = dst->op_params[2] / sizeof(float);
|
||||
const int64_t offset = dst->op_params[3] / sizeof(float);
|
||||
|
||||
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
|
||||
}
|
||||
|
||||
@@ -678,10 +678,14 @@ void launch_fattn(
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
@@ -689,6 +693,10 @@ void launch_fattn(
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
@@ -713,10 +721,10 @@ void launch_fattn(
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
||||
@@ -733,7 +741,7 @@ void launch_fattn(
|
||||
nb13 = nb13*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
|
||||
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 32;
|
||||
static constexpr int nbatch_V2 = 32;
|
||||
static constexpr int nbatch_combine = 32;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 40;
|
||||
static constexpr int nbatch_V2 = 40;
|
||||
static constexpr int nbatch_combine = 40;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 48;
|
||||
static constexpr int nbatch_V2 = 48;
|
||||
static constexpr int nbatch_combine = 48;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 56;
|
||||
static constexpr int nbatch_V2 = 56;
|
||||
static constexpr int nbatch_combine = 56;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 64;
|
||||
static constexpr int nbatch_V2 = 64;
|
||||
static constexpr int nbatch_combine = 64;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 128;
|
||||
static constexpr int nbatch_V2 = 128;
|
||||
static constexpr int nbatch_combine = 128;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
}
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
#else
|
||||
GGML_UNUSED(ncols);
|
||||
return 128;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
|
||||
static constexpr int nwarps_max = 8;
|
||||
static constexpr bool Q_in_reg = false;
|
||||
static constexpr int nstages_target = 1;
|
||||
static constexpr int nbatch_K2 = 160;
|
||||
static constexpr int nbatch_V2 = 128;
|
||||
static constexpr int nbatch_combine = 128;
|
||||
|
||||
static int get_nbatch_K2_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 96 : 160;
|
||||
}
|
||||
return ncols <= 16 ? 288 : 160;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 96 : 160;
|
||||
#else
|
||||
return ncols <= 16 ? 288 : 160;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
}
|
||||
return ncols <= 16 ? 256 : 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return ncols <= 16 ? 256 : 128;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------------
|
||||
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||
|
||||
auto load = [&] __device__ (const int n) {
|
||||
auto load = [&] __device__ (auto n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
|
||||
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
||||
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
||||
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||
} else {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
|
||||
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||
const int k0_diff = k0_stop - k0_start;
|
||||
|
||||
if (nstages <= 1) {
|
||||
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
|
||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
|
||||
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if (nstages <= 1) {
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||
#pragma unroll
|
||||
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||
|
||||
// Calculate VKQ tile:
|
||||
#pragma unroll
|
||||
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
||||
|
||||
tile_A A;
|
||||
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
if (ntiles == 1) {
|
||||
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
||||
} else {
|
||||
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
||||
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||
|
||||
extern __shared__ half2 tile_Q[];
|
||||
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||
constexpr bool use_cp_async = true;
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
}
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
||||
constexpr bool last_iter = false;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
}
|
||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||
constexpr bool last_iter = true;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
}
|
||||
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
||||
|
||||
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
|
||||
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
|
||||
constexpr int tile_stride = nbatch_combine + 4;
|
||||
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
||||
|
||||
@@ -895,6 +1079,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
} else if (np > 1) {
|
||||
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
|
||||
// Therefore, all other warps also need to execute a __syncthreads().
|
||||
// Otherwise the points at which warps synchronize with each other would become misaligned.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -1007,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -1052,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
|
||||
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||
|
||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||
|
||||
@@ -1062,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int stride_Q1 = nb01 / sizeof(float2);
|
||||
const int stride_Q2 = nb02 / sizeof(float2);
|
||||
const int stride_K = nb11 / sizeof(half2);
|
||||
const int stride_V = nb21 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half2);
|
||||
|
||||
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
@@ -1087,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
@@ -1099,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
@@ -1125,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
@@ -1136,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
@@ -1162,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
|
||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||
|
||||
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
||||
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
||||
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
||||
|
||||
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
@@ -1175,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
||||
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
||||
|
||||
constexpr bool mla = DKQ == 576;
|
||||
|
||||
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
|
||||
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
|
||||
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
||||
|
||||
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
||||
static_assert(DV % tile_A::J == 0, "bad DV");
|
||||
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
||||
|
||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||
|
||||
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
||||
|
||||
@@ -1197,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1208,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#ifndef GGML_USE_HIP
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // GGML_USE_HIP
|
||||
static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
if (ncols > 1) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
kqsum_shared[j][threadIdx.x] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ half maskh_shared[ncols*D];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskh_shared[j*D + tid] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
if (mask) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
|
||||
// In such cases, skip the KV slice.
|
||||
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
|
||||
#ifndef GGML_USE_HIP
|
||||
bool skip = true;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
|
||||
skip = skip && isinf(tmp.x) && isinf(tmp.y);
|
||||
}
|
||||
}
|
||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||
continue;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
||||
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
|
||||
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
|
||||
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
sum += maskh_shared[j*D + i_KQ];
|
||||
|
||||
if (ncols == 1) {
|
||||
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
|
||||
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#ifndef GGML_USE_HIP
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // GGML_USE_HIP
|
||||
static __global__ void flash_attn_vec_ext_f32(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
if (ncols > 1) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
kqsum_shared[j][threadIdx.x] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ float maskf_shared[ncols*D];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskf_shared[j*D + tid] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
@@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
if (mask) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
|
||||
// In such cases, skip the KV slice.
|
||||
// On AMD __all_sync would not work correctly because it assumes a warp size of 64.
|
||||
#ifndef GGML_USE_HIP
|
||||
bool skip = true;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
skip = skip && isinf(maskf_shared[j*D + i]);
|
||||
}
|
||||
}
|
||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||
continue;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
float kqmax_new_arr[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
@@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
sum += maskf_shared[j*D + i_KQ];
|
||||
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
|
||||
@@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
constexpr int cols_per_block = 1;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
template <int DKQ, int DV, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
||||
if (!new_mma_available(cc)) {
|
||||
return false;
|
||||
}
|
||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||
|
||||
@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
@@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
|
||||
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
|
||||
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||
|
||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i1 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
|
||||
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
|
||||
|
||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
|
||||
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
|
||||
|
||||
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
|
||||
|
||||
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
||||
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
|
||||
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||
switch (mmq_get_q8_1_ds_layout(type_src0)) {
|
||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||
|
||||
@@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
@@ -207,6 +207,10 @@ typedef struct {
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
int32_t sect_0;
|
||||
int32_t sect_1;
|
||||
int32_t sect_2;
|
||||
int32_t sect_3;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
|
||||
@@ -332,6 +332,10 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||
@@ -411,6 +415,13 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
||||
@@ -1275,6 +1286,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||
@@ -1354,6 +1369,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
||||
@@ -1637,16 +1659,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
@@ -3826,6 +3839,7 @@ static bool ggml_metal_encode_node(
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
|
||||
// make sure we have one or more position id(ne10) per token(ne02)
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
GGML_ASSERT(ne10 >= ne02);
|
||||
@@ -3852,20 +3866,42 @@ static bool ggml_metal_encode_node(
|
||||
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
// mrope
|
||||
const int sect_0 = ((const int32_t *) dst->op_params)[11];
|
||||
const int sect_1 = ((const int32_t *) dst->op_params)[12];
|
||||
const int sect_2 = ((const int32_t *) dst->op_params)[13];
|
||||
const int sect_3 = ((const int32_t *) dst->op_params)[14];
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (!is_neox) {
|
||||
if (is_neox) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else if (is_mrope && !is_vision) {
|
||||
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else if (is_vision) {
|
||||
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
}
|
||||
@@ -3896,6 +3932,10 @@ static bool ggml_metal_encode_node(
|
||||
/*.attn_factor =*/ attn_factor,
|
||||
/*.beta_fast =*/ beta_fast,
|
||||
/*.beta_slow =*/ beta_slow,
|
||||
/* sect_0 =*/ sect_0,
|
||||
/* sect_1 =*/ sect_1,
|
||||
/* sect_2 =*/ sect_2,
|
||||
/* sect_3 =*/ sect_3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
@@ -4332,7 +4372,7 @@ static bool ggml_metal_encode_node(
|
||||
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
||||
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
@@ -4513,6 +4553,24 @@ static bool ggml_metal_encode_node(
|
||||
use_vec_kernel = true;
|
||||
|
||||
switch (ne00) {
|
||||
case 64:
|
||||
{
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
||||
GGML_LOG_ERROR("add template specialization for this type\n");
|
||||
GGML_ABORT("add template specialization for this type");
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case 96:
|
||||
{
|
||||
switch (src1->type) {
|
||||
|
||||
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_multi(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations
|
||||
// note: the rest is the same as kernel_rope_neox
|
||||
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
|
||||
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
if (sector < args.sect_0) {
|
||||
theta_base = (float) pos[i2];
|
||||
} else if (sector < sec_w01) {
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
} else if (sector < sec_w012) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
// end of mrope
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_vision(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations (only support 2 dimensions)
|
||||
const int sect_dims = args.sect_0 + args.sect_1;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float p;
|
||||
float theta_base;
|
||||
if (sector < args.sect_1) {
|
||||
p = (float) sector;
|
||||
theta_base = (float) pos[i2];
|
||||
} else {
|
||||
p = (float) sector - args.sect_0;
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
||||
// end of mrope
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
|
||||
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
|
||||
typedef void (im2col_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
@@ -3109,7 +3255,7 @@ template<
|
||||
typename kd4x4_t, // key type in device memory
|
||||
short nl_k,
|
||||
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
||||
typename vd4x4_t, // key type in device memory
|
||||
typename vd4x4_t, // value type in device memory
|
||||
short nl_v,
|
||||
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
||||
short DK, // K head size
|
||||
@@ -3630,7 +3776,7 @@ template<
|
||||
typename kd4_t, // key type in device memory
|
||||
short nl_k,
|
||||
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
||||
typename vd4_t, // key type in device memory
|
||||
typename vd4_t, // value type in device memory
|
||||
short nl_v,
|
||||
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
||||
short DK, // K head size
|
||||
@@ -3741,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
sm[tiisg] = pm[ic + tiisg];
|
||||
}
|
||||
|
||||
// skip -INF blocks
|
||||
if (simd_max(sm[tiisg]) == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
||||
@@ -3973,6 +4124,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
||||
|
||||
@@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
func = ggml_cl_add;
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
|
||||
@@ -28,16 +28,19 @@ struct ggml_opt_dataset {
|
||||
};
|
||||
|
||||
struct ggml_opt_context {
|
||||
ggml_backend_sched_t backend_sched = nullptr;
|
||||
ggml_cgraph * allocated_graph = nullptr;
|
||||
ggml_cgraph * allocated_graph_copy = nullptr;
|
||||
struct ggml_context * ctx_static = nullptr;
|
||||
struct ggml_context * ctx_static_cpu = nullptr;
|
||||
struct ggml_context * ctx_compute = nullptr;
|
||||
struct ggml_context * ctx_copy = nullptr;
|
||||
ggml_backend_buffer_t buf_static = nullptr;
|
||||
ggml_backend_buffer_t buf_static_cpu = nullptr;
|
||||
std::mt19937 rng;
|
||||
ggml_backend_sched_t backend_sched = nullptr;
|
||||
ggml_cgraph * allocated_graph = nullptr;
|
||||
ggml_cgraph * allocated_graph_copy = nullptr;
|
||||
struct ggml_context * ctx_static = nullptr;
|
||||
struct ggml_context * ctx_cpu = nullptr;
|
||||
struct ggml_context * ctx_compute = nullptr;
|
||||
struct ggml_context * ctx_copy = nullptr;
|
||||
ggml_backend_buffer_t buf_static = nullptr;
|
||||
ggml_backend_buffer_t buf_cpu = nullptr;
|
||||
std::mt19937 rng;
|
||||
enum ggml_opt_loss_type loss_type;
|
||||
enum ggml_opt_build_type build_type;
|
||||
enum ggml_opt_build_type build_type_alloc;
|
||||
|
||||
struct ggml_tensor * inputs = nullptr;
|
||||
struct ggml_tensor * outputs = nullptr;
|
||||
@@ -50,6 +53,11 @@ struct ggml_opt_context {
|
||||
struct ggml_cgraph * gf = nullptr;
|
||||
struct ggml_cgraph * gb_grad = nullptr;
|
||||
struct ggml_cgraph * gb_opt = nullptr;
|
||||
bool static_graphs = false;
|
||||
bool eval_ready = false;
|
||||
std::vector<struct ggml_tensor *> grad_accs;
|
||||
std::vector<struct ggml_tensor *> grad_m;
|
||||
std::vector<struct ggml_tensor *> grad_v;
|
||||
|
||||
int64_t iter = 1;
|
||||
int32_t opt_period = 1;
|
||||
@@ -73,7 +81,13 @@ struct ggml_opt_result {
|
||||
|
||||
// ====== Dataset ======
|
||||
|
||||
ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
|
||||
ggml_opt_dataset_t ggml_opt_dataset_init(
|
||||
enum ggml_type type_data,
|
||||
enum ggml_type type_label,
|
||||
int64_t ne_datapoint,
|
||||
int64_t ne_label,
|
||||
int64_t ndata,
|
||||
int64_t ndata_shard) {
|
||||
GGML_ASSERT(ne_datapoint > 0);
|
||||
GGML_ASSERT(ne_label >= 0);
|
||||
GGML_ASSERT(ndata > 0);
|
||||
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
|
||||
result->ctx = ggml_init(params);
|
||||
}
|
||||
|
||||
result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
|
||||
result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
|
||||
result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
|
||||
|
||||
if (ne_label > 0) {
|
||||
result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
|
||||
result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
|
||||
result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
|
||||
} else {
|
||||
result->labels = nullptr;
|
||||
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
|
||||
delete dataset;
|
||||
}
|
||||
|
||||
int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
|
||||
return dataset->ndata;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
|
||||
return dataset->data;
|
||||
}
|
||||
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
|
||||
GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
|
||||
GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
|
||||
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
|
||||
GGML_ASSERT( data_batch->type == dataset->data->type);
|
||||
GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
|
||||
|
||||
const size_t nb_data_batch = ggml_nbytes(data_batch);
|
||||
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
|
||||
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
|
||||
GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
|
||||
GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
|
||||
|
||||
const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
|
||||
|
||||
GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
|
||||
|
||||
for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
|
||||
const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
|
||||
|
||||
const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
|
||||
char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
|
||||
memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
|
||||
|
||||
if (!labels_batch) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
|
||||
char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
|
||||
memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
|
||||
}
|
||||
}
|
||||
|
||||
// ====== Model / Context ======
|
||||
|
||||
struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
|
||||
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
|
||||
return *((struct ggml_opt_optimizer_params *) userdata);
|
||||
}
|
||||
|
||||
struct ggml_opt_params ggml_opt_default_params(
|
||||
ggml_backend_sched_t backend_sched,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs,
|
||||
enum ggml_opt_loss_type loss_type) {
|
||||
return {
|
||||
/*backend_sched =*/ backend_sched,
|
||||
/*ctx_compute =*/ ctx_compute,
|
||||
/*inputs =*/ inputs,
|
||||
/*logits =*/ outputs,
|
||||
/*ctx_compute =*/ nullptr,
|
||||
/*inputs =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*loss_type =*/ loss_type,
|
||||
/*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
|
||||
/*opt_period =*/ 1,
|
||||
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
|
||||
return dst;
|
||||
}
|
||||
|
||||
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
|
||||
GGML_ASSERT(graph);
|
||||
if (opt_ctx->allocated_graph == graph) {
|
||||
return;
|
||||
}
|
||||
static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
||||
GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
|
||||
GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
|
||||
|
||||
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
|
||||
const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
|
||||
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
|
||||
|
||||
{
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
opt_ctx->ctx_copy = ggml_init(params);
|
||||
}
|
||||
|
||||
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
|
||||
|
||||
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
opt_ctx->allocated_graph = graph;
|
||||
}
|
||||
|
||||
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
||||
ggml_opt_context_t result = new struct ggml_opt_context;
|
||||
result->backend_sched = params.backend_sched;
|
||||
result->ctx_compute = params.ctx_compute;
|
||||
result->inputs = params.inputs;
|
||||
result->outputs = params.outputs;
|
||||
result->opt_period = params.opt_period;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
|
||||
GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
|
||||
GGML_ASSERT(result->opt_period >= 1);
|
||||
|
||||
const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
|
||||
(params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
|
||||
|
||||
ggml_set_input(result->inputs);
|
||||
ggml_set_output(result->outputs);
|
||||
|
||||
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
|
||||
ggml_build_forward_expand(result->gf, result->outputs);
|
||||
ggml_set_input(opt_ctx->inputs);
|
||||
ggml_set_output(opt_ctx->outputs);
|
||||
|
||||
int n_param = 0;
|
||||
for (int i = 0; i < result->gf->n_nodes; ++i) {
|
||||
if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
|
||||
const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
n_param++;
|
||||
}
|
||||
GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
|
||||
}
|
||||
|
||||
{
|
||||
if (!opt_ctx->ctx_static) {
|
||||
// The static context is used for:
|
||||
// - gradients (1 tensor per param if using gradient accumulation)
|
||||
// - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
|
||||
// - optimizer momenta (2 tensors per param)
|
||||
// - labels
|
||||
// - loss + its gradient (up to 5 tensors)
|
||||
// - pred
|
||||
// - ncorrect (2 tensors).
|
||||
const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
||||
const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
|
||||
// - labels (if using static graphs)
|
||||
// - loss (if using static graphs, up to 5 tensors)
|
||||
// - pred (if using static graphs)
|
||||
// - ncorrect (if using static graphs, 2 tensors).
|
||||
constexpr size_t n_loss = 1;
|
||||
const size_t tensors_per_param = (accumulate ? 1 : 0) +
|
||||
(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
||||
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
|
||||
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
result->ctx_static = ggml_init(params);
|
||||
opt_ctx->ctx_static = ggml_init(params);
|
||||
}
|
||||
GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
|
||||
|
||||
{
|
||||
// The static cpu context is used for:
|
||||
// - optimizer parameters (1 for the entire context)
|
||||
// The cpu context is allocated statically if using static graphs, dynamically otherwise.
|
||||
// It is used for:
|
||||
// - optimizer parameters (1 shared for all optimizer invocations)
|
||||
const size_t size_meta = 1 * ggml_tensor_overhead();
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
result->ctx_static_cpu = ggml_init(params);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
opt_ctx->ctx_cpu = ggml_init(params);
|
||||
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
opt_ctx->buf_cpu = nullptr;
|
||||
}
|
||||
|
||||
struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
|
||||
|
||||
switch (params.loss_type) {
|
||||
switch (opt_ctx->loss_type) {
|
||||
case GGML_OPT_LOSS_TYPE_MEAN: {
|
||||
result->loss = ggml_sum(result->ctx_static, result->outputs);
|
||||
ggml_set_name(result->loss, "loss_sum");
|
||||
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
|
||||
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
|
||||
ggml_set_name(result->loss, "loss_mean");
|
||||
result->loss_per_datapoint = true;
|
||||
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_name(opt_ctx->loss, "loss_sum");
|
||||
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
|
||||
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
|
||||
ggml_set_name(opt_ctx->loss, "loss_mean");
|
||||
opt_ctx->loss_per_datapoint = true;
|
||||
break;
|
||||
}
|
||||
case GGML_OPT_LOSS_TYPE_SUM: {
|
||||
result->loss = ggml_sum(result->ctx_static, result->outputs);
|
||||
ggml_set_name(result->loss, "loss_sum");
|
||||
result->loss_per_datapoint = false;
|
||||
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_name(opt_ctx->loss, "loss_sum");
|
||||
opt_ctx->loss_per_datapoint = false;
|
||||
break;
|
||||
}
|
||||
case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
|
||||
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
|
||||
ggml_set_input(result->labels);
|
||||
ggml_set_name(result->labels, "labels");
|
||||
result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
|
||||
ggml_set_name(result->loss, "loss_cross_entropy");
|
||||
if (result->opt_period > 1) {
|
||||
result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
|
||||
ggml_set_name(result->loss, "loss_cross_entropy_scaled");
|
||||
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_input(opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->labels, "labels");
|
||||
opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
|
||||
if (opt_ctx->opt_period > 1) {
|
||||
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
|
||||
ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
|
||||
}
|
||||
result->loss_per_datapoint = true;
|
||||
opt_ctx->loss_per_datapoint = true;
|
||||
break;
|
||||
}
|
||||
case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
|
||||
result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
|
||||
ggml_set_input(result->labels);
|
||||
ggml_set_name(result->labels, "labels");
|
||||
result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
|
||||
ggml_set_name(result->loss, "loss_error");
|
||||
result->loss = ggml_sqr(result->ctx_static, result->loss);
|
||||
ggml_set_name(result->loss, "loss_squared_error");
|
||||
result->loss = ggml_sum(result->ctx_static, result->loss);
|
||||
ggml_set_name(result->loss, "loss_sum_squared_error");
|
||||
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
|
||||
result->loss = ggml_scale(result->ctx_static, result->loss, scale);
|
||||
ggml_set_name(result->loss, "loss_mean_squared_error");
|
||||
result->loss_per_datapoint = true;
|
||||
opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_input(opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->labels, "labels");
|
||||
opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
|
||||
ggml_set_name(opt_ctx->loss, "loss_error");
|
||||
opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
|
||||
ggml_set_name(opt_ctx->loss, "loss_squared_error");
|
||||
opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
|
||||
ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
|
||||
const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
|
||||
opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
|
||||
ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
|
||||
opt_ctx->loss_per_datapoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ggml_set_output(result->loss);
|
||||
ggml_set_loss(result->loss);
|
||||
ggml_build_forward_expand(result->gf, result->loss);
|
||||
ggml_set_output(opt_ctx->loss);
|
||||
ggml_set_loss(opt_ctx->loss);
|
||||
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
|
||||
|
||||
result->pred = ggml_argmax(result->ctx_static, result->outputs);
|
||||
ggml_set_name(result->pred, "pred");
|
||||
ggml_set_output(result->pred);
|
||||
ggml_build_forward_expand(result->gf, result->pred);
|
||||
if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
|
||||
opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
|
||||
ggml_set_name(opt_ctx->pred, "pred");
|
||||
ggml_set_output(opt_ctx->pred);
|
||||
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
|
||||
|
||||
if (result->labels) {
|
||||
result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
|
||||
ggml_set_name(result->ncorrect, "ncorrect");
|
||||
ggml_set_output(result->ncorrect);
|
||||
ggml_build_forward_expand(result->gf, result->ncorrect);
|
||||
} else {
|
||||
result->ncorrect = nullptr;
|
||||
opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
|
||||
ggml_set_name(opt_ctx->ncorrect, "ncorrect");
|
||||
ggml_set_output(opt_ctx->ncorrect);
|
||||
ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
|
||||
}
|
||||
|
||||
if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
|
||||
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
||||
return result;
|
||||
if (opt_ctx->buf_static) {
|
||||
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
|
||||
return;
|
||||
}
|
||||
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
|
||||
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
|
||||
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
||||
return;
|
||||
}
|
||||
|
||||
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
|
||||
result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
|
||||
ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
|
||||
if (opt_ctx->grad_accs.empty()) {
|
||||
GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
|
||||
|
||||
if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
|
||||
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
||||
ggml_graph_reset(result->gb_grad);
|
||||
return result;
|
||||
}
|
||||
const int n_nodes = opt_ctx->gf->n_nodes;
|
||||
opt_ctx->grad_accs.resize(n_nodes);
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
|
||||
opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
} else {
|
||||
opt_ctx->grad_accs[i] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
|
||||
|
||||
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
||||
result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
|
||||
|
||||
result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
|
||||
ggml_set_input(result->adamw_params);
|
||||
ggml_set_name(result->adamw_params, "adamw_params");
|
||||
|
||||
for (int i = result->gf->n_nodes-1; i >= 0; --i) {
|
||||
struct ggml_tensor * node = result->gb_opt->nodes[i];
|
||||
struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
|
||||
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
|
||||
struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
|
||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
|
||||
ggml_build_forward_expand(result->gb_opt, opt_step);
|
||||
if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
|
||||
opt_ctx->grad_m.resize(n_nodes);
|
||||
opt_ctx->grad_v.resize(n_nodes);
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
ggml_tensor * node = opt_ctx->gf->nodes[i];
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
} else {
|
||||
opt_ctx->grad_m[i] = nullptr;
|
||||
opt_ctx->grad_v[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result->buf_static = ggml_backend_alloc_ctx_tensors(
|
||||
result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
|
||||
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
|
||||
opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
|
||||
ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
|
||||
|
||||
result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
|
||||
if (opt_ctx->buf_static) {
|
||||
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
|
||||
return;
|
||||
}
|
||||
} else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
|
||||
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
||||
ggml_graph_reset(opt_ctx->gb_grad);
|
||||
}
|
||||
|
||||
ggml_graph_reset(result->gb_opt);
|
||||
GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
|
||||
|
||||
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
||||
opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
|
||||
|
||||
opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
|
||||
ggml_set_input(opt_ctx->adamw_params);
|
||||
ggml_set_name(opt_ctx->adamw_params, "adamw_params");
|
||||
|
||||
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
|
||||
struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
|
||||
struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
|
||||
|
||||
if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
|
||||
struct ggml_tensor * m = opt_ctx->grad_m[i];
|
||||
struct ggml_tensor * v = opt_ctx->grad_v[i];
|
||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
|
||||
|
||||
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
|
||||
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
|
||||
ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
|
||||
|
||||
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
|
||||
}
|
||||
}
|
||||
|
||||
if (!opt_ctx->buf_static) {
|
||||
opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
|
||||
opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
||||
ggml_graph_reset(opt_ctx->gb_opt);
|
||||
}
|
||||
|
||||
opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
|
||||
}
|
||||
|
||||
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
||||
ggml_opt_context_t result = new struct ggml_opt_context;
|
||||
result->backend_sched = params.backend_sched;
|
||||
result->ctx_compute = params.ctx_compute;
|
||||
result->loss_type = params.loss_type;
|
||||
result->build_type = params.build_type;
|
||||
result->build_type_alloc = params.build_type;
|
||||
result->inputs = params.inputs;
|
||||
result->outputs = params.outputs;
|
||||
result->opt_period = params.opt_period;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
|
||||
GGML_ASSERT(result->opt_period >= 1);
|
||||
|
||||
result->static_graphs = result->ctx_compute;
|
||||
|
||||
if (!result->static_graphs) {
|
||||
GGML_ASSERT(!result->inputs);
|
||||
GGML_ASSERT(!result->outputs);
|
||||
return result;
|
||||
}
|
||||
|
||||
GGML_ASSERT(result->inputs);
|
||||
GGML_ASSERT(result->outputs);
|
||||
|
||||
result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
|
||||
ggml_build_forward_expand(result->gf, result->outputs);
|
||||
|
||||
ggml_opt_build(result);
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
|
||||
return;
|
||||
}
|
||||
ggml_backend_buffer_free(opt_ctx->buf_static);
|
||||
ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
ggml_free(opt_ctx->ctx_static);
|
||||
ggml_free(opt_ctx->ctx_static_cpu);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
delete opt_ctx;
|
||||
}
|
||||
|
||||
@@ -479,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
|
||||
return opt_ctx->static_graphs;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
||||
return opt_ctx->inputs;
|
||||
}
|
||||
@@ -582,8 +683,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
|
||||
|
||||
// ====== Computation ======
|
||||
|
||||
static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
|
||||
if (graph != opt_ctx->gf) {
|
||||
void ggml_opt_prepare_alloc(
|
||||
ggml_opt_context_t opt_ctx,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * inputs,
|
||||
struct ggml_tensor * outputs) {
|
||||
GGML_ASSERT(!opt_ctx->static_graphs);
|
||||
opt_ctx->ctx_compute = ctx_compute;
|
||||
opt_ctx->gf = gf;
|
||||
opt_ctx->inputs = inputs;
|
||||
opt_ctx->outputs = outputs;
|
||||
}
|
||||
|
||||
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
|
||||
GGML_ASSERT(!opt_ctx->eval_ready);
|
||||
if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
|
||||
ggml_graph_reset(opt_ctx->gb_grad);
|
||||
}
|
||||
if (backward) {
|
||||
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
||||
opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
|
||||
} else {
|
||||
opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
|
||||
}
|
||||
|
||||
if (!opt_ctx->static_graphs) {
|
||||
ggml_opt_build(opt_ctx);
|
||||
}
|
||||
|
||||
struct ggml_cgraph * graph = nullptr;
|
||||
switch (opt_ctx->build_type) {
|
||||
case GGML_OPT_BUILD_TYPE_FORWARD: {
|
||||
graph = opt_ctx->gf;
|
||||
} break;
|
||||
case GGML_OPT_BUILD_TYPE_GRAD: {
|
||||
graph = opt_ctx->gb_grad;
|
||||
} break;
|
||||
case GGML_OPT_BUILD_TYPE_OPT: {
|
||||
graph = opt_ctx->gb_opt;
|
||||
} break;
|
||||
}
|
||||
GGML_ASSERT(graph);
|
||||
|
||||
if (opt_ctx->allocated_graph == graph) {
|
||||
opt_ctx->eval_ready = true;
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
|
||||
|
||||
if (opt_ctx->static_graphs) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
opt_ctx->ctx_copy = ggml_init(params);
|
||||
|
||||
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
|
||||
} else {
|
||||
opt_ctx->allocated_graph_copy = graph;
|
||||
}
|
||||
|
||||
ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
opt_ctx->allocated_graph = graph;
|
||||
|
||||
opt_ctx->eval_ready = true;
|
||||
}
|
||||
|
||||
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
|
||||
GGML_ASSERT(opt_ctx->eval_ready);
|
||||
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
|
||||
struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
|
||||
|
||||
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
|
||||
@@ -609,9 +781,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
|
||||
adamw_par_data[6] = beta2h;
|
||||
}
|
||||
|
||||
ggml_opt_alloc_graph(opt_ctx, graph);
|
||||
ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
|
||||
opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
||||
|
||||
if (!opt_ctx->static_graphs) {
|
||||
opt_ctx->gf = nullptr;
|
||||
opt_ctx->gb_grad = nullptr;
|
||||
opt_ctx->gb_opt = nullptr;
|
||||
opt_ctx->allocated_graph = nullptr;
|
||||
opt_ctx->allocated_graph_copy = nullptr;
|
||||
}
|
||||
|
||||
opt_ctx->eval_ready = false;
|
||||
|
||||
if (!result) {
|
||||
return;
|
||||
@@ -635,12 +817,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
|
||||
ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
|
||||
result->loss.push_back(loss);
|
||||
|
||||
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
|
||||
std::vector<int32_t> pred(ndata);
|
||||
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
|
||||
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
|
||||
if (opt_ctx->pred) {
|
||||
GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
|
||||
std::vector<int32_t> pred(ndata);
|
||||
ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
|
||||
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
|
||||
}
|
||||
|
||||
if (!opt_ctx->labels || result->ncorrect < 0) {
|
||||
if (!opt_ctx->ncorrect || result->ncorrect < 0) {
|
||||
result->ncorrect = -1;
|
||||
return;
|
||||
}
|
||||
@@ -652,26 +836,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
|
||||
result->ncorrect += ncorrect;
|
||||
}
|
||||
|
||||
void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
|
||||
}
|
||||
|
||||
void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
|
||||
if (opt_ctx->opt_period == 1) {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
||||
if (opt_i_next == 0) {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
|
||||
ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
|
||||
} else {
|
||||
ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
|
||||
}
|
||||
opt_ctx->opt_i = opt_i_next;
|
||||
}
|
||||
|
||||
// ====== High-Level Functions ======
|
||||
|
||||
void ggml_opt_epoch(
|
||||
@@ -682,6 +846,7 @@ void ggml_opt_epoch(
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
|
||||
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|
||||
@@ -700,16 +865,18 @@ void ggml_opt_epoch(
|
||||
int64_t ibatch = 0;
|
||||
int64_t t_loop_start = ggml_time_us();
|
||||
for (; ibatch < ibatch_split; ++ibatch) {
|
||||
ggml_opt_alloc(opt_ctx, /*backward =*/ true);
|
||||
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
|
||||
ggml_opt_forward_backward(opt_ctx, result_train);
|
||||
ggml_opt_eval(opt_ctx, result_train);
|
||||
if (callback_train) {
|
||||
callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
|
||||
}
|
||||
}
|
||||
t_loop_start = ggml_time_us();
|
||||
for (; ibatch < nbatches; ++ibatch) {
|
||||
ggml_opt_alloc(opt_ctx, /*backward =*/ false);
|
||||
ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
|
||||
ggml_opt_forward(opt_ctx, result_eval);
|
||||
ggml_opt_eval(opt_ctx, result_eval);
|
||||
if (callback_eval) {
|
||||
callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
|
||||
}
|
||||
@@ -726,13 +893,26 @@ void ggml_opt_epoch_callback_progress_bar(
|
||||
int64_t t_start_us) {
|
||||
fprintf(stderr, "%s[", train ? "train: " : "val: ");
|
||||
|
||||
constexpr int64_t bar_length = 25;
|
||||
// The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
|
||||
constexpr int64_t bar_length = 8;
|
||||
const int64_t ibatch8 = 8 * ibatch;
|
||||
for (int64_t j = 0; j < bar_length; ++j) {
|
||||
const int64_t ibatch_j = ibatch_max * j/bar_length;
|
||||
if (ibatch_j < ibatch) {
|
||||
fprintf(stderr, "=");
|
||||
} else if (ibatch_max * (j - 1)/bar_length < ibatch) {
|
||||
fprintf(stderr, ">");
|
||||
if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u2588"); // full block
|
||||
} else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u2589"); // 7/8 filled
|
||||
} else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258A"); // 6/8 filled
|
||||
} else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258B"); // 5/8 filled
|
||||
} else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258C"); // 4/8 filled
|
||||
} else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258D"); // 3/8 filled
|
||||
} else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258E"); // 2/8 filled
|
||||
} else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
|
||||
fprintf(stderr, "\u258F"); // 1/8 filled
|
||||
} else {
|
||||
fprintf(stderr, " ");
|
||||
}
|
||||
@@ -764,8 +944,8 @@ void ggml_opt_epoch_callback_progress_bar(
|
||||
const int64_t t_eta_m = t_eta_s / 60;
|
||||
t_eta_s -= t_eta_m * 60;
|
||||
|
||||
fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
|
||||
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
|
||||
fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
|
||||
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
|
||||
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
|
||||
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
|
||||
if (ibatch == ibatch_max) {
|
||||
@@ -806,7 +986,10 @@ void ggml_opt_fit(
|
||||
|
||||
int64_t epoch = 1;
|
||||
|
||||
ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
|
||||
ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
|
||||
params.ctx_compute = ctx_compute;
|
||||
params.inputs = inputs;
|
||||
params.outputs = outputs;
|
||||
params.opt_period = opt_period;
|
||||
params.get_opt_pars = get_opt_pars;
|
||||
params.get_opt_pars_ud = &epoch;
|
||||
|
||||
@@ -49,35 +49,38 @@ endif()
|
||||
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
|
||||
|
||||
# Link against oneDNN
|
||||
find_package(DNNL)
|
||||
set(GGML_SYCL_DNNL 0)
|
||||
if(DNNL_FOUND)
|
||||
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# Assuming oneDNN packaged with oneapi release is used which
|
||||
# supports only intel target
|
||||
set(DNNL_GPU_VENDOR "INTEL")
|
||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||
if(GGML_SYCL_DNN)
|
||||
find_package(DNNL)
|
||||
if(DNNL_FOUND)
|
||||
if (NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# default to intel target
|
||||
set(DNNL_GPU_VENDOR "INTEL")
|
||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Verify oneDNN was compiled for the same target as llama
|
||||
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||
set(GGML_SYCL_DNNL 1)
|
||||
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
||||
foreach(CONFIG ${CONFIGS})
|
||||
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
||||
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
||||
endforeach()
|
||||
# Verify oneDNN was compiled for the same target as llama
|
||||
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||
set(GGML_SYCL_DNNL 1)
|
||||
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
||||
foreach(CONFIG ${CONFIGS})
|
||||
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
||||
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
||||
endforeach()
|
||||
else()
|
||||
message(WARNING
|
||||
"oneDNN must be compiled for the same target as llama.cpp.
|
||||
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
||||
Disabling oneDNN support.")
|
||||
endif()
|
||||
else()
|
||||
message(WARNING
|
||||
"oneDNN must be compiled for the same target as llama.cpp.
|
||||
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
||||
Disabling oneDNN support.")
|
||||
message(STATUS "oneDNN not found, disabling oneDNN support")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "oneDNN not found, disabling oneDNN support")
|
||||
message(STATUS "oneDNN support disabled by the user")
|
||||
endif()
|
||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
|
||||
|
||||
@@ -108,6 +111,9 @@ endif()
|
||||
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
# Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
|
||||
# See https://github.com/uxlfoundation/oneMath/issues/654
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
set(SYCL_COMPILER ON)
|
||||
endif()
|
||||
find_package(MKL REQUIRED)
|
||||
target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
|
||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
|
||||
|
||||
@@ -1,93 +1,74 @@
|
||||
#include "binbcast.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) /
|
||||
ne3;
|
||||
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) %
|
||||
ne3;
|
||||
|
||||
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
for (int i0 = i0s; i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
|
||||
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
|
||||
auto element_id = it.get_global_id(0);
|
||||
auto global_range = it.get_global_range(0);
|
||||
for (; element_id < num_elements; element_id += global_range) {
|
||||
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
|
||||
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
|
||||
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
|
||||
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
|
||||
dst[element_id] = val_to_store;
|
||||
}
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
|
||||
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
|
||||
int s11, int s12, int s13, std::size_t num_dst_elements,
|
||||
const sycl::nd_item<1> & item_ct1) {
|
||||
auto calculate_logical_index =
|
||||
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
|
||||
std::array<int, 4> logical_index;
|
||||
#pragma unroll(4)
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
logical_index[i] = element_id % dims[i];
|
||||
element_id /= dims[i];
|
||||
}
|
||||
return logical_index;
|
||||
};
|
||||
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
|
||||
const std::array<int, 4> & indices) __attribute__((always_inline))
|
||||
->std::size_t {
|
||||
std::size_t index = 0;
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < 4; i++) {
|
||||
auto index_i = indices[i];
|
||||
if (indices[i] >= dims[i]) {
|
||||
index_i = indices[i] % dims[i];
|
||||
}
|
||||
index += strides[i] * index_i;
|
||||
}
|
||||
return index;
|
||||
};
|
||||
|
||||
const int i3 = i/(ne2*ne1*ne0);
|
||||
const int i2 = (i/(ne1*ne0)) % ne2;
|
||||
const int i1 = (i/ne0) % ne1;
|
||||
const int i0 = i % ne0;
|
||||
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
auto element_id = item_ct1.get_global_id(0);
|
||||
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
|
||||
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
|
||||
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
|
||||
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
|
||||
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
|
||||
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
|
||||
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
|
||||
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
|
||||
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
|
||||
dst[dst_index] = val_to_store;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
}
|
||||
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
struct bin_bcast_sycl {
|
||||
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
|
||||
template <typename src0_t, typename src1_t, typename dst_t>
|
||||
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
|
||||
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
|
||||
@@ -96,165 +77,73 @@ struct bin_bcast_sycl {
|
||||
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
|
||||
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
|
||||
int nr0 = ne10 / ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
int nr3 = ne13/ne3;
|
||||
|
||||
int nr[4] = { nr0, nr1, nr2, nr3 };
|
||||
|
||||
// collapse dimensions until first broadcast dimension
|
||||
int64_t cne[] = {ne0, ne1, ne2, ne3};
|
||||
int64_t cne0[] = {ne00, ne01, ne02, ne03};
|
||||
int64_t cne1[] = {ne10, ne11, ne12, ne13};
|
||||
size_t cnb[] = {nb0, nb1, nb2, nb3};
|
||||
size_t cnb0[] = {nb00, nb01, nb02, nb03};
|
||||
size_t cnb1[] = {nb10, nb11, nb12, nb13};
|
||||
auto collapse = [](int64_t cne[]) {
|
||||
cne[0] *= cne[1];
|
||||
cne[1] = cne[2];
|
||||
cne[2] = cne[3];
|
||||
cne[3] = 1;
|
||||
};
|
||||
|
||||
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
|
||||
cnb[1] *= cne[1];
|
||||
cnb[2] *= cne[2];
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
|
||||
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
|
||||
const std::array<int64_t, 4> & dst_dims) -> bool {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
collapse_nb(cnb, cne);
|
||||
collapse_nb(cnb0, cne0);
|
||||
collapse_nb(cnb1, cne1);
|
||||
collapse(cne);
|
||||
collapse(cne0);
|
||||
collapse(cne1);
|
||||
if (dst_dims[i] > src_dims[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
int64_t ne0 = cne[0];
|
||||
int64_t ne1 = cne[1];
|
||||
int64_t ne2 = cne[2];
|
||||
int64_t ne3 = cne[3];
|
||||
return false;
|
||||
};
|
||||
|
||||
int64_t ne10 = cne1[0];
|
||||
int64_t ne11 = cne1[1];
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
size_t nb0 = cnb[0];
|
||||
size_t nb1 = cnb[1];
|
||||
size_t nb2 = cnb[2];
|
||||
size_t nb3 = cnb[3];
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
|
||||
size_t nb00 = cnb0[0];
|
||||
size_t nb01 = cnb0[1];
|
||||
size_t nb02 = cnb0[2];
|
||||
size_t nb03 = cnb0[3];
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
// dst strides in number of elements
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
// src1 strides in number of elements
|
||||
size_t s10 = nb10 / sizeof(src0_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
// src0 strides in number of elements
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
GGML_UNUSED(s00);
|
||||
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
|
||||
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
|
||||
std::size_t local_range = 256;
|
||||
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
|
||||
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
|
||||
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
|
||||
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
||||
|
||||
sycl::range<3> block_dims(1, 1, 1);
|
||||
block_dims[2] = std::min<unsigned int>(hne0, block_size);
|
||||
block_dims[1] = std::min<unsigned int>(
|
||||
ne1, block_size / (unsigned int)block_dims[2]);
|
||||
block_dims[0] = std::min(
|
||||
std::min<unsigned int>(
|
||||
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
|
||||
(unsigned int)block_dims[1]),
|
||||
64U);
|
||||
|
||||
sycl::range<3> block_nums(
|
||||
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
|
||||
(ne1 + block_dims[1] - 1) / block_dims[1],
|
||||
(hne0 + block_dims[2] - 1) / block_dims[2]);
|
||||
|
||||
if (block_nums[0] > 65535) {
|
||||
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
|
||||
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast_unravel<bin_op>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
|
||||
s03, s11, s12, s13, item_ct1);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:16: The work-group size passed to the SYCL kernel may
|
||||
exceed the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if
|
||||
needed.
|
||||
*/
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
||||
ne2, ne3, ne10, ne11, ne12, ne13,
|
||||
s1, s2, s3, s01, s02, s03, s11, s12, s13,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
if (! needs_broadcasting && all_contiguous) {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
|
||||
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
|
||||
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
||||
const int64_t nb = k / QK_K;
|
||||
const size_t local_size = 32;
|
||||
const size_t global_size = nb * local_size;
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::queue_ptr stream) {
|
||||
@@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
||||
case GGML_TYPE_Q3_K:
|
||||
return dequantize_row_q3_K_sycl;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return dequantize_row_q4_K_sycl;
|
||||
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q4_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q4_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
@@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
case GGML_TYPE_Q3_K:
|
||||
return dequantize_row_q3_K_sycl;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return dequantize_row_q4_K_sycl;
|
||||
if (dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q4_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q4_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
|
||||
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename dst_t>
|
||||
inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
|
||||
const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
|
||||
const int is = 2 * il;
|
||||
constexpr int n = 4;
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||
const float d1 = dall * sc;
|
||||
const float m1 = dmin * m;
|
||||
|
||||
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||
const float d2 = dall * sc;
|
||||
const float m2 = dmin * m;
|
||||
|
||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
||||
y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
||||
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
||||
const int64_t i = item_ct1.get_group(2);
|
||||
|
||||
#if QK_K == 256
|
||||
// assume 32 threads
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const int64_t il = tid/8;
|
||||
const int64_t ir = tid%8;
|
||||
const int64_t is = 2*il;
|
||||
const int64_t n = 4;
|
||||
const int64_t il = tid / 8;
|
||||
const int64_t ir = tid % 8;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
|
||||
|
||||
const sycl::half2 dm = x[i].dm;
|
||||
const float dall = dm[0];
|
||||
const float dmin = dm[1];
|
||||
|
||||
if (tid < 12)
|
||||
if (tid < 12) {
|
||||
scales_local[tid] = x[i].scales[tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||
const float d1 = dall * sc;
|
||||
const float m1 = dmin * m;
|
||||
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||
const float d2 = dall * sc;
|
||||
const float m2 = dmin * m;
|
||||
|
||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
||||
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
|
||||
#else
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const uint8_t * q = x[i].qs;
|
||||
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
|
||||
const sycl::nd_item<1> & item_ct1, int64_t nb) {
|
||||
const int64_t i = item_ct1.get_group(0); // block index
|
||||
const int64_t tid = item_ct1.get_local_id(0); // thread index within block
|
||||
const int64_t il = tid / 8;
|
||||
const int64_t ir = tid % 8;
|
||||
|
||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vx);
|
||||
const size_t qs_offset = i * (QK_K / 2);
|
||||
const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
|
||||
const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
|
||||
|
||||
const uint8_t * qs_ptr = base + qs_offset;
|
||||
const uint8_t * scales_ptr = base + scales_offset;
|
||||
ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
|
||||
|
||||
const float dall = dm_values.x();
|
||||
const float dmin = dm_values.y();
|
||||
|
||||
if (tid < 12) {
|
||||
scales_local[tid] = scales_ptr[tid];
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
||||
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
||||
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
// reorder is currently not supported for dmmv
|
||||
GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
|
||||
} else {
|
||||
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
|
||||
@@ -655,7 +655,6 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -688,7 +687,6 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,7 +720,6 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -754,7 +751,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -786,7 +782,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -818,7 +813,6 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -850,7 +844,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -883,7 +876,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -917,7 +909,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -949,7 +940,6 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -981,7 +971,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1013,7 +1002,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1045,7 +1033,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1078,7 +1065,6 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1110,7 +1096,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1142,7 +1127,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1174,7 +1158,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1206,7 +1189,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1241,7 +1223,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1273,7 +1254,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1315,7 +1295,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1350,7 +1329,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1388,7 +1366,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,16 +32,36 @@ public:
|
||||
else static_assert(0);
|
||||
}
|
||||
|
||||
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
|
||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||
// matrix A has m rows, k columns
|
||||
// matrix B has k rows, n columns
|
||||
// nra - number of elements to skip when moving into next row in A
|
||||
// nrb - number of elements to skip when moving into next row in B
|
||||
// nca - number of elements to skip when moving into next column in A
|
||||
// ncb - number of elements to skip when moving into next column in B
|
||||
// stride_a - number of elements to skip when moving to next A matrix
|
||||
// stride_b - number of elements to skip when moving to next B matrix
|
||||
// batches_a - number of A matrices
|
||||
// batches_b - number of B matrices
|
||||
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
|
||||
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
|
||||
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
|
||||
|
||||
auto stream = ctx.stream_dnnl(q);
|
||||
auto eng = ctx.engine_dnnl(q);
|
||||
dnnl::memory::dims a_dims = { m, k };
|
||||
dnnl::memory::dims b_dims = { k, n };
|
||||
dnnl::memory::dims c_dims = { m, n };
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||
|
||||
// { # strides, # rows, # columns }
|
||||
dnnl::memory::dims a_dims = { batches_a, m, k };
|
||||
dnnl::memory::dims b_dims = { batches_b, k, n };
|
||||
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
|
||||
|
||||
// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
|
||||
dnnl::memory::dims a_strides = { stride_a, nra, nca };
|
||||
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
|
||||
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
|
||||
|
||||
dnnl::primitive_attr primitive_attr;
|
||||
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
@@ -63,6 +83,15 @@ public:
|
||||
|
||||
matmul_prim.execute(stream, matmul_args);
|
||||
}
|
||||
|
||||
// matrices A and B are column major, both having k rows
|
||||
// matrix A has m column, matrix B has n columns
|
||||
// output: column major matrix C = A transposed * B
|
||||
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||
|
||||
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_disable_dnn = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
|
||||
#endif
|
||||
#if GGML_SYCL_DNNL
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
|
||||
#endif
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
@@ -341,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
@@ -374,16 +385,17 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
|
||||
ggml_sycl_set_device(ctx->device);
|
||||
auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
|
||||
#ifndef _WIN32
|
||||
// Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.
|
||||
// This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.
|
||||
char* host_buf = (char*)malloc(size);
|
||||
char * host_buf = (char *) malloc(size);
|
||||
memcpy(host_buf, data, size);
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
|
||||
.wait()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait()));
|
||||
free(host_buf);
|
||||
#else
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait()));
|
||||
#endif
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
@@ -1985,19 +1997,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
int id;
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
#if !GGML_SYCL_DNNL
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
const int64_t ne0 = dst->ne[0]; // used by MKL only
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// ldc == nrows of the matrix that cuBLAS writes into
|
||||
int ldc = id == ctx.device ? ne0 : row_diff;
|
||||
#endif
|
||||
int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
bool use_fp16 = true; // TODO(Yu) SYCL capability check
|
||||
@@ -2033,25 +2044,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
: src1_as_f16.get();
|
||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
|
||||
|
||||
#if !GGML_SYCL_DNNL
|
||||
const sycl::half alpha_f16 = 1.0f;
|
||||
const sycl::half beta_f16 = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||
*stream, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
||||
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
||||
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
||||
dpct::library_data_t::real_half)));
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
#else
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
const sycl::half alpha_f16 = 1.0f;
|
||||
const sycl::half beta_f16 = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||
*stream, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
||||
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
||||
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
||||
dpct::library_data_t::real_half)));
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
|
||||
@@ -2072,18 +2087,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
|
||||
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
|
||||
|
||||
#if !GGML_SYCL_DNNL
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
||||
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
||||
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
||||
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||
#else
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
||||
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
||||
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
||||
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_ddq_i);
|
||||
@@ -2697,7 +2716,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
|
||||
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
|
||||
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
|
||||
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
|
||||
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
|
||||
@@ -2713,7 +2732,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
|
||||
|
||||
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
|
||||
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
|
||||
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
|
||||
uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
|
||||
|
||||
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
|
||||
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
|
||||
@@ -2726,6 +2745,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
@@ -2766,7 +2786,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
}
|
||||
|
||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
||||
char * dst_t = reinterpret_cast<char *>(dst_ddf);
|
||||
|
||||
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
|
||||
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
|
||||
@@ -2783,42 +2802,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
|
||||
GGML_ASSERT(ne10 == ne00);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
const int64_t r3 = ne13 / ne03;
|
||||
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
|
||||
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
||||
} else {
|
||||
const int ne23 = ne12 * ne13;
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
|
||||
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
|
||||
|
||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
||||
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
||||
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
|
||||
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
|
||||
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
|
||||
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
|
||||
};
|
||||
|
||||
sycl::range<3> block_dims(1, ne12, ne13);
|
||||
queue->submit([&](sycl::handler & cgh) {
|
||||
const void ** ptrs_src_get = ptrs_src.get();
|
||||
void ** ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||
if (r2 == 1 && r3 == 1) {
|
||||
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
|
||||
}
|
||||
else {
|
||||
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
|
||||
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
|
||||
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
|
||||
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// iterate over batches from smaller set of matrices (matrix 0)
|
||||
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
|
||||
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
|
||||
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
|
||||
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
|
||||
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
|
||||
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
||||
} else {
|
||||
const int ne23 = ne12 * ne13;
|
||||
|
||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
||||
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
||||
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||
|
||||
sycl::range<3> block_dims(1, ne12, ne13);
|
||||
queue->submit([&](sycl::handler & cgh) {
|
||||
const void ** ptrs_src_get = ptrs_src.get();
|
||||
void ** ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
||||
}
|
||||
}
|
||||
} catch (const sycl::exception & exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||
@@ -2841,6 +2901,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return !g_ggml_sycl_prioritize_dmmv;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2858,6 +2920,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -2883,16 +2946,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
|
||||
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
||||
dpct::queue_ptr stream) {
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
|
||||
.wait()));
|
||||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
|
||||
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
||||
|
||||
stream->parallel_for(
|
||||
@@ -2906,25 +2969,66 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
|
||||
}
|
||||
*(d_ptr + ib) = x[ib].d;
|
||||
});
|
||||
}).wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
|
||||
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
|
||||
|
||||
const int nblocks = size / sizeof(block_q4_K);
|
||||
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||
|
||||
auto * qs_ptr = data_device;
|
||||
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
|
||||
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
|
||||
|
||||
stream->parallel_for(nblocks, [=](auto i) {
|
||||
const block_q4_K * x = (const block_q4_K *) tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
for (int j = 0; j < QK_K / 2; ++j) {
|
||||
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < K_SCALE_SIZE; ++j) {
|
||||
scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
|
||||
}
|
||||
|
||||
dm_ptr[ib] = x[ib].dm;
|
||||
}).wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
char*data_device = (char*)src0->data;
|
||||
uint8_t * data_device = (uint8_t *) src0->data;
|
||||
size_t ncols = src0->ne[0];
|
||||
size_t nrows = src0->ne[1];
|
||||
size_t size = ggml_nbytes(src0);
|
||||
|
||||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
reorder_qw_q4_k(data_device, size, 0, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("reorder_qw() called with unsupported type");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
|
||||
return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
||||
dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
|
||||
dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
|
||||
}
|
||||
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
|
||||
@@ -2960,8 +3064,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
|
||||
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
||||
static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||
src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
}
|
||||
|
||||
static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||
src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
|
||||
@@ -2984,13 +3098,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
}
|
||||
|
||||
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||
bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
|
||||
|
||||
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
@@ -3041,8 +3151,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
|
||||
} else {
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
// MUL_MAT_SYCL supports reorder
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
|
||||
}
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
@@ -3713,7 +3821,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
|
||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
|
||||
|
||||
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
model_sycl_graph.end_recording();
|
||||
|
||||
@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
const int blocks_per_row = ncols / block_traits::qk;
|
||||
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
||||
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
||||
const int nblocks = nrows * (ncols / block_traits::qk);
|
||||
|
||||
static_assert(blocks_per_subgroup > 0);
|
||||
static_assert(block_elements_per_subgroup > 0);
|
||||
@@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
|
||||
nrows, nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
@@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
|
||||
@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
|
||||
template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK_K;
|
||||
static constexpr uint32_t qi = QI4_K;
|
||||
static constexpr uint32_t qr = QR4_K;
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
auto nblocks = (nrows * (ncols / traits::qk));
|
||||
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
|
||||
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
||||
|
||||
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
|
||||
};
|
||||
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
||||
|
||||
@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
@@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
};
|
||||
};
|
||||
|
||||
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
|
||||
const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs) {
|
||||
int v[2];
|
||||
int u[2 * QR4_K];
|
||||
float d8[QR4_K];
|
||||
|
||||
v[0] = q4[0];
|
||||
v[1] = q4[4];
|
||||
|
||||
uint16_t aux[2];
|
||||
const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
|
||||
if (j < 2) {
|
||||
aux[0] = scales[j + 0] & 0x3f3f;
|
||||
aux[1] = scales[j + 2] & 0x3f3f;
|
||||
} else {
|
||||
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
|
||||
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
|
||||
}
|
||||
|
||||
const uint8_t * sc = (const uint8_t *) aux;
|
||||
const uint8_t * m = sc + 2;
|
||||
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
|
||||
for (int i = 0; i < QR4_K; ++i) {
|
||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||
d8[i] = bq8i->ds[0];
|
||||
|
||||
const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);
|
||||
u[2 * i + 0] = q8[0];
|
||||
u[2 * i + 1] = q8[4];
|
||||
}
|
||||
|
||||
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);
|
||||
}
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q4_K;
|
||||
|
||||
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
||||
using q4_k_traits = typename q4_k_block::traits;
|
||||
|
||||
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
|
||||
const int ib = ibx_offset / (QK_K / 2);
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||
const uint8_t * qs = base + ibx_offset;
|
||||
const int total_qs_bytes = nblocks * (QK_K / 2);
|
||||
const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
|
||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
|
||||
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||
const uint16_t * scales = (const uint16_t *) scs;
|
||||
|
||||
return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
|
||||
}
|
||||
};
|
||||
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
@@ -649,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
|
||||
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
|
||||
}
|
||||
|
||||
static __dpct_inline__ float
|
||||
vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
|
||||
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
|
||||
|
||||
static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs) {
|
||||
#ifndef GGML_QKK_64
|
||||
|
||||
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
|
||||
|
||||
int v[2];
|
||||
int u[2*QR4_K];
|
||||
float d8[QR4_K];
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||
const uint16_t * scales = (const uint16_t *) bq4_K->scales;
|
||||
|
||||
// iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
|
||||
const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
|
||||
|
||||
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
|
||||
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
|
||||
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
|
||||
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
|
||||
|
||||
const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
|
||||
v[0] = q4[0];
|
||||
v[1] = q4[4];
|
||||
|
||||
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
|
||||
uint16_t aux[2];
|
||||
const int j = bq8_offset/2;
|
||||
if (j < 2) {
|
||||
aux[0] = scales[j+0] & 0x3f3f;
|
||||
aux[1] = scales[j+2] & 0x3f3f;
|
||||
} else {
|
||||
aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
|
||||
aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
|
||||
}
|
||||
const uint8_t * sc = (const uint8_t *)aux;
|
||||
const uint8_t * m = sc + 2;
|
||||
|
||||
for (int i = 0; i < QR4_K; ++i) {
|
||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||
d8[i] = bq8i->ds[0];
|
||||
|
||||
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
|
||||
u[2*i+0] = q8[0];
|
||||
u[2*i+1] = q8[4];
|
||||
}
|
||||
|
||||
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
|
||||
return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);
|
||||
|
||||
#else
|
||||
|
||||
|
||||
@@ -15,6 +15,32 @@ function(detect_host_compiler)
|
||||
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Function to test shader extension support
|
||||
# Parameters:
|
||||
# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product")
|
||||
# TEST_SHADER_FILE - Path to the test shader file
|
||||
# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result
|
||||
function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)
|
||||
execute_process(
|
||||
COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*")
|
||||
message(STATUS "${EXTENSION_NAME} not supported by glslc")
|
||||
set(${RESULT_VARIABLE} OFF PARENT_SCOPE)
|
||||
else()
|
||||
message(STATUS "${EXTENSION_NAME} supported by glslc")
|
||||
set(${RESULT_VARIABLE} ON PARENT_SCOPE)
|
||||
add_compile_definitions(${RESULT_VARIABLE})
|
||||
|
||||
# Ensure the extension support is forwarded to vulkan-shaders-gen
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)
|
||||
set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if (Vulkan_FOUND)
|
||||
message(STATUS "Vulkan found")
|
||||
|
||||
@@ -23,69 +49,40 @@ if (Vulkan_FOUND)
|
||||
../../include/ggml-vulkan.h
|
||||
)
|
||||
|
||||
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
set(VULKAN_SHADER_GEN_CMAKE_ARGS
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
||||
-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
|
||||
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "")
|
||||
if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo")
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE})
|
||||
endif()
|
||||
|
||||
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
# Test all shader extensions
|
||||
test_shader_extension_support(
|
||||
"GL_KHR_cooperative_matrix"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
|
||||
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
endif()
|
||||
test_shader_extension_support(
|
||||
"GL_NV_cooperative_matrix2"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_integer_dot_product"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
|
||||
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
|
||||
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
|
||||
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
|
||||
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
endif()
|
||||
|
||||
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
|
||||
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
|
||||
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_EXT_bfloat16 supported by glslc")
|
||||
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
endif()
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_bfloat16"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
|
||||
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
|
||||
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||
@@ -124,16 +121,8 @@ if (Vulkan_FOUND)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
endif()
|
||||
|
||||
if (NOT CMAKE_CROSSCOMPILING)
|
||||
add_subdirectory(vulkan-shaders)
|
||||
if (MSVC)
|
||||
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
|
||||
string(TOUPPER ${CONFIG} CONFIG)
|
||||
set_target_properties(vulkan-shaders-gen PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
endforeach()
|
||||
endif()
|
||||
else()
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
|
||||
else()
|
||||
@@ -146,25 +135,31 @@ if (Vulkan_FOUND)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
|
||||
endif()
|
||||
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
|
||||
|
||||
include(ExternalProject)
|
||||
# Native build through ExternalProject_Add
|
||||
ExternalProject_Add(
|
||||
vulkan-shaders-gen
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
||||
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
)
|
||||
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
|
||||
else()
|
||||
# For non-cross-compiling, use empty toolchain (use host compiler)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE "")
|
||||
endif()
|
||||
|
||||
# Always use ExternalProject_Add approach
|
||||
include(ExternalProject)
|
||||
|
||||
# Add toolchain file if cross-compiling
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
|
||||
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
|
||||
endif()
|
||||
|
||||
# Native build through ExternalProject_Add
|
||||
ExternalProject_Add(
|
||||
vulkan-shaders-gen
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS}
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
)
|
||||
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
|
||||
|
||||
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
|
||||
set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
|
||||
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
|
||||
@@ -175,9 +170,8 @@ if (Vulkan_FOUND)
|
||||
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
|
||||
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
|
||||
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
|
||||
endif()
|
||||
# Add build and install dependencies for all builds
|
||||
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${_ggml_vk_header}
|
||||
|
||||
@@ -288,6 +288,9 @@ struct vk_device_struct {
|
||||
bool coopmat_acc_f32_support {};
|
||||
bool coopmat_acc_f16_support {};
|
||||
bool coopmat_bf16_support {};
|
||||
bool coopmat_support_16x16x16_f16acc {};
|
||||
bool coopmat_support_16x16x16_f32acc {};
|
||||
bool coopmat1_fa_support {};
|
||||
uint32_t coopmat_m;
|
||||
uint32_t coopmat_n;
|
||||
uint32_t coopmat_k;
|
||||
@@ -410,6 +413,13 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
||||
@@ -1588,19 +1598,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
||||
);
|
||||
}
|
||||
|
||||
enum FaCodePath {
|
||||
FA_SCALAR,
|
||||
FA_COOPMAT1,
|
||||
FA_COOPMAT2,
|
||||
};
|
||||
|
||||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
||||
|
||||
static uint32_t get_fa_num_small_rows(bool scalar) {
|
||||
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
|
||||
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
||||
// 128 threads split into four subgroups, each subgroup does 1/4
|
||||
// of the Bc dimension.
|
||||
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
|
||||
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
||||
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
||||
|
||||
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
||||
if (path == FA_COOPMAT2) {
|
||||
return flash_attention_num_small_rows;
|
||||
} else {
|
||||
return scalar_flash_attention_num_small_rows;
|
||||
}
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
GGML_UNUSED(clamp);
|
||||
|
||||
if (scalar) {
|
||||
if (path == FA_SCALAR) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
} else {
|
||||
@@ -1608,9 +1635,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
|
||||
}
|
||||
}
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
||||
} else {
|
||||
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
||||
}
|
||||
}
|
||||
|
||||
// small rows, large cols
|
||||
if (small_rows) {
|
||||
return {get_fa_num_small_rows(scalar), 32};
|
||||
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
||||
}
|
||||
|
||||
// small cols to reduce register count
|
||||
@@ -1907,17 +1942,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
// For scalar, use 128 (arbitrary)
|
||||
uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
|
||||
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
||||
? scalar_flash_attention_workgroup_size
|
||||
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
@@ -1929,36 +1966,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
||||
};
|
||||
|
||||
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
|
||||
|
||||
CREATE_FA(GGML_TYPE_F16, f16, true, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
}
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
||||
}
|
||||
#endif
|
||||
#undef CREATE_FA2
|
||||
@@ -1987,25 +2031,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||
}
|
||||
#endif
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
|
||||
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
@@ -2041,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -2073,47 +2117,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
#endif
|
||||
|
||||
if (device->coopmat_acc_f16_support) {
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
} else {
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
@@ -2188,13 +2232,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
} \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
} \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
} \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -2208,34 +2258,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -2284,13 +2334,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
@@ -2322,11 +2372,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -3009,6 +3059,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
#if defined(VK_KHR_cooperative_matrix)
|
||||
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
||||
|
||||
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
||||
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
||||
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
||||
device->subgroup_max_size >= 32;
|
||||
#endif
|
||||
|
||||
if (coopmat2_support) {
|
||||
@@ -3143,6 +3198,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_acc_f32_support = true;
|
||||
}
|
||||
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
||||
device->coopmat_support_16x16x16_f32acc = true;
|
||||
}
|
||||
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
||||
// coopmat sizes not set yet
|
||||
@@ -3155,6 +3213,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_acc_f16_support = true;
|
||||
}
|
||||
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
||||
device->coopmat_support_16x16x16_f16acc = true;
|
||||
}
|
||||
}
|
||||
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
||||
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
||||
@@ -3656,7 +3717,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
||||
}
|
||||
|
||||
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
||||
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
|
||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_matmul_f32;
|
||||
}
|
||||
@@ -3684,7 +3745,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
|
||||
// MMQ
|
||||
if (src1_type == GGML_TYPE_Q8_1) {
|
||||
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
|
||||
vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
|
||||
|
||||
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
||||
return nullptr;
|
||||
@@ -3724,9 +3785,12 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
|
||||
if (ctx->device->coopmat2) {
|
||||
assert(src1_type == GGML_TYPE_F16);
|
||||
return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
|
||||
return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
|
||||
}
|
||||
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
||||
if (ctx->device->coopmat_support) {
|
||||
return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
||||
}
|
||||
return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
|
||||
@@ -5688,6 +5752,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
||||
|
||||
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
||||
|
||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||
|
||||
const uint32_t kshstride = D / 4 + 2;
|
||||
const uint32_t ksh = Bc * kshstride * f16vec4;
|
||||
|
||||
const uint32_t slope = Br * sizeof(float);
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
||||
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
||||
@@ -5738,7 +5832,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
assert(q->type == GGML_TYPE_F32);
|
||||
assert(k->type == v->type);
|
||||
|
||||
bool scalar = !ctx->device->coopmat2;
|
||||
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
||||
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
||||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
||||
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
|
||||
|
||||
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
@@ -5746,9 +5852,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
// For scalar FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
|
||||
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
uint32_t max_gqa;
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = scalar_flash_attention_num_large_rows;
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
@@ -5761,11 +5879,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
|
||||
vk_pipeline *pipelines;
|
||||
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
|
||||
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
|
||||
bool small_rows = N <= get_fa_num_small_rows(scalar);
|
||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||
|
||||
if (scalar) {
|
||||
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
||||
// So use scalar instead.
|
||||
if (small_rows && path == FA_COOPMAT1) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// scalar is faster than coopmat2 when N==1
|
||||
if (N == 1 && path == FA_COOPMAT2) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
||||
@@ -5777,7 +5907,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
@@ -5789,6 +5933,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
assert(pipelines);
|
||||
|
||||
@@ -10123,7 +10270,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_CONCAT) {
|
||||
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_UPSCALE) {
|
||||
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
|
||||
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
|
||||
} else if (tensor->op == GGML_OP_SCALE) {
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
|
||||
@@ -10412,7 +10559,8 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
||||
ggml_vk_print_graph_origin(tensor, done);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) {
|
||||
const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;
|
||||
if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
|
||||
first_error[0] = i0;
|
||||
first_error[1] = i1;
|
||||
first_error[2] = i2;
|
||||
@@ -10424,7 +10572,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
||||
// Special case, value is infinite, avoid NaN result in avg_err
|
||||
// NaN also appears in results, if both are nan error is 0
|
||||
if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
|
||||
avg_err += std::fabs(correct - result);
|
||||
avg_err += std::fabs(correct - result) / denom;
|
||||
}
|
||||
counter++;
|
||||
}
|
||||
@@ -10459,7 +10607,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
||||
ggml_vk_print_graph_origin(tensor, done);
|
||||
}
|
||||
|
||||
if (avg_err > 0.05 || std::isnan(avg_err)) {
|
||||
if (avg_err > 0.5 || std::isnan(avg_err)) {
|
||||
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
|
||||
std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
|
||||
if (src0 != nullptr) {
|
||||
|
||||
@@ -5,18 +5,35 @@ find_package (Threads REQUIRED)
|
||||
|
||||
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat2 glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling dot glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling bfloat16 glslc support")
|
||||
endif()
|
||||
|
||||
set(TARGET vulkan-shaders-gen)
|
||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
|
||||
|
||||
# Configure output directories for MSVC builds
|
||||
if(MSVC)
|
||||
# Get the main project's runtime output directory if possible
|
||||
if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
|
||||
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
|
||||
string(TOUPPER ${CONFIG} CONFIG)
|
||||
set_target_properties(${TARGET} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -9,59 +9,13 @@
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
|
||||
#include "types.comp"
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
@@ -70,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
@@ -113,29 +34,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared vec4 tmpshv4[gl_WorkGroupSize.x];
|
||||
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
||||
shared vec4 tmpshv4[WorkGroupSize];
|
||||
|
||||
shared float masksh[Bc][Br];
|
||||
shared vec4 Qf[Br][D / 4];
|
||||
@@ -145,58 +45,12 @@ void main() {
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
init_indices();
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
|
||||
162
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
Normal file
162
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
Normal file
@@ -0,0 +1,162 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 4) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
|
||||
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
|
||||
q_stride, k_stride, v_stride, m_stride;
|
||||
|
||||
void init_indices()
|
||||
{
|
||||
N = p.N;
|
||||
KV = p.KV;
|
||||
|
||||
i = gl_WorkGroupID.x;
|
||||
split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
Tr = CEIL_DIV(N, Br);
|
||||
|
||||
start_j = split_k_index * p.split_kv / Bc;
|
||||
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
rk2 = p.neq2/p.nek2;
|
||||
rk3 = p.neq3/p.nek3;
|
||||
|
||||
rv2 = p.neq2/p.nev2;
|
||||
rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
ik3 = iq3 / rk3;
|
||||
ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
iv3 = iq3 / rv3;
|
||||
iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
k_stride = p.nb11;
|
||||
v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
}
|
||||
360
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Normal file
360
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Normal file
@@ -0,0 +1,360 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
#include "types.comp"
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
const uint32_t row_split = 4;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
layout (binding = 1) readonly buffer K {float16_t data_k[];};
|
||||
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * D + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
const uint32_t MatBc = 16;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
// Avoid padding for D==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
|
||||
shared float slope[Br];
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
init_indices();
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t r = (idx + tid) / (D / 4);
|
||||
if (r < Br && d < D / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = ACC_TYPEV4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
float Lf[rows_per_thread], Mf[rows_per_thread];
|
||||
|
||||
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
||||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = 0;
|
||||
Mf[r] = NEG_FLT_MAX_OVER_2;
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
if (tid < Br) {
|
||||
uint r = tid;
|
||||
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||
}
|
||||
barrier();
|
||||
} else {
|
||||
if (tid < Br) {
|
||||
uint r = tid;
|
||||
slope[r] = 1.0;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t c = (idx + tid) / (D / 4);
|
||||
if (c < Bc && d < D / 4) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
#else
|
||||
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
|
||||
ksh[c * kshstride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < D / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
||||
}
|
||||
|
||||
uint coord = gl_SubgroupID * MatBc * sfshstride;
|
||||
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
barrier();
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / Br;
|
||||
uint32_t r = (idx + tid) % Br;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
float eMf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
||||
}
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// P = e^(S - M)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf[r] = exp(Moldf - Mf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
float Pf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||
Lf[r] += Pf[r];
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
// reduce across threads
|
||||
|
||||
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE M = Mf[r];
|
||||
tmpsh[tid] = M;
|
||||
// Compute max across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
M = max(M, tmpsh[tid ^ s]);
|
||||
barrier();
|
||||
tmpsh[tid] = M;
|
||||
barrier();
|
||||
}
|
||||
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Moldf[r] = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE L = Lf[r];
|
||||
tmpsh[tid] = L;
|
||||
// Compute sum across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
L += tmpsh[tid ^ s];
|
||||
barrier();
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
Of[r][d] += tmpshv4[tid ^ s];
|
||||
barrier();
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = D * p.ne1 * split_k_index;
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
float Lfrcp[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= float16_t(Lfrcp[r]);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (i * Br + tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,62 +18,12 @@
|
||||
|
||||
#include "types.comp"
|
||||
#include "dequant_funcs_cm2.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 32;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
||||
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return max(x, y);
|
||||
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
init_indices();
|
||||
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
|
||||
@@ -195,17 +90,6 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
||||
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
|
||||
@@ -215,7 +215,7 @@ static std::mutex compile_count_mutex;
|
||||
static std::condition_variable compile_count_cond;
|
||||
|
||||
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
std::string out_fname = join_paths(output_dir, name + ".spv");
|
||||
std::string in_path = join_paths(input_dir, in_fname);
|
||||
|
||||
@@ -424,6 +424,7 @@ void process_shaders() {
|
||||
// flash attention
|
||||
for (const auto& f16acc : {false, true}) {
|
||||
std::string acctype = f16acc ? "float16_t" : "float";
|
||||
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
@@ -440,6 +441,16 @@ void process_shaders() {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
}
|
||||
#endif
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
|
||||
106
ggml/src/ggml.c
106
ggml/src/ggml.c
@@ -64,12 +64,17 @@
|
||||
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
|
||||
float ggml_table_f32_f16[1 << 16];
|
||||
|
||||
#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
|
||||
(!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
|
||||
#if defined(__linux__) || \
|
||||
defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
|
||||
(defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
|
||||
|
||||
#include <unistd.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/wait.h>
|
||||
#if defined(__linux__)
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#include <unwind.h>
|
||||
@@ -133,10 +138,36 @@ static void ggml_print_backtrace(void) {
|
||||
if (GGML_NO_BACKTRACE) {
|
||||
return;
|
||||
}
|
||||
char attach[32];
|
||||
snprintf(attach, sizeof(attach), "attach %d", getpid());
|
||||
int pid = fork();
|
||||
if (pid == 0) {
|
||||
#if defined(__linux__)
|
||||
FILE * f = fopen("/proc/self/status", "r");
|
||||
size_t size = 0;
|
||||
char * line = NULL;
|
||||
ssize_t length = 0;
|
||||
while ((length = getline(&line, &size, f)) > 0) {
|
||||
if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
|
||||
(length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
|
||||
// Already being debugged, and the breakpoint is the later abort()
|
||||
free(line);
|
||||
fclose(f);
|
||||
return;
|
||||
}
|
||||
}
|
||||
free(line);
|
||||
fclose(f);
|
||||
int lock[2] = { -1, -1 };
|
||||
(void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
|
||||
#endif
|
||||
const int parent_pid = getpid();
|
||||
const int child_pid = fork();
|
||||
if (child_pid < 0) { // error
|
||||
return;
|
||||
} else if (child_pid == 0) { // child
|
||||
char attach[32];
|
||||
snprintf(attach, sizeof(attach), "attach %d", parent_pid);
|
||||
#if defined(__linux__)
|
||||
close(lock[1]);
|
||||
(void) !read(lock[0], lock, 1);
|
||||
#endif
|
||||
// try gdb
|
||||
execlp("gdb", "gdb", "--batch",
|
||||
"-ex", "set style enabled on",
|
||||
@@ -149,18 +180,18 @@ static void ggml_print_backtrace(void) {
|
||||
execlp("lldb", "lldb", "--batch",
|
||||
"-o", "bt",
|
||||
"-o", "quit",
|
||||
"-p", attach,
|
||||
"-p", &attach[sizeof("attach ") - 1],
|
||||
(char *) NULL);
|
||||
exit(EXIT_FAILURE);
|
||||
} else {
|
||||
int wstatus;
|
||||
waitpid(pid, &wstatus, 0);
|
||||
if (WIFEXITED(wstatus)) {
|
||||
if (WEXITSTATUS(wstatus) == EXIT_FAILURE) {
|
||||
// gdb failed, fallback to backtrace_symbols
|
||||
ggml_print_backtrace_symbols();
|
||||
}
|
||||
}
|
||||
// gdb failed, fallback to backtrace_symbols
|
||||
ggml_print_backtrace_symbols();
|
||||
_Exit(0);
|
||||
} else { // parent
|
||||
#if defined(__linux__)
|
||||
prctl(PR_SET_PTRACER, child_pid);
|
||||
close(lock[1]);
|
||||
close(lock[0]);
|
||||
#endif
|
||||
waitpid(child_pid, NULL, 0);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -5499,7 +5530,7 @@ static void ggml_compute_backward(
|
||||
// tensor = src0 * 1 + src1 * 0
|
||||
if (src0_needs_grads) {
|
||||
// dsrc0 = dtensor * 1
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, grad);
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
// dsrc1 = dtensor * 0 -> noop
|
||||
@@ -5780,10 +5811,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
|
||||
}
|
||||
|
||||
void ggml_build_backward_expand(
|
||||
struct ggml_context * ctx_static,
|
||||
struct ggml_context * ctx_compute,
|
||||
struct ggml_cgraph * cgraph,
|
||||
bool accumulate) {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor ** grad_accs) {
|
||||
GGML_ASSERT(cgraph->n_nodes > 0);
|
||||
GGML_ASSERT(cgraph->grads);
|
||||
GGML_ASSERT(cgraph->grad_accs);
|
||||
@@ -5856,21 +5886,24 @@ void ggml_build_backward_expand(
|
||||
GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
|
||||
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
|
||||
|
||||
const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||
GGML_ASSERT(igrad != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
|
||||
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
|
||||
cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
|
||||
cgraph->grads[igrad] = cgraph->grad_accs[igrad];
|
||||
ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
|
||||
const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||
GGML_ASSERT(ihash != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
|
||||
if (grad_accs && grad_accs[i]) {
|
||||
cgraph->grad_accs[ihash] = grad_accs[i];
|
||||
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
|
||||
} else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
|
||||
// loss tensors always need a gradient accumulator
|
||||
cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
|
||||
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
|
||||
}
|
||||
grads_needed[igrad] = true;
|
||||
grads_needed[ihash] = true;
|
||||
}
|
||||
|
||||
for (int i = n_nodes_f - 1; i >= 0; --i) {
|
||||
// inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
|
||||
// use allocator to automatically make inplace operations
|
||||
ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
|
||||
ggml_compute_backward(ctx, cgraph, i, grads_needed);
|
||||
}
|
||||
|
||||
free(grads_needed);
|
||||
@@ -6016,8 +6049,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
|
||||
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
|
||||
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
|
||||
struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
|
||||
ggml_graph_cpy(cgraph, result);
|
||||
return result;
|
||||
}
|
||||
@@ -6036,6 +6069,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
||||
}
|
||||
|
||||
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
||||
if (!cgraph) {
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(cgraph->grads != NULL);
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
@@ -6345,8 +6381,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
|
||||
tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
|
||||
}
|
||||
|
||||
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
||||
GGML_UNUSED(ctx); // TODO: remove this parameter
|
||||
void ggml_set_param(struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor->op == GGML_OP_NONE);
|
||||
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
|
||||
}
|
||||
|
||||
|
||||
@@ -299,10 +299,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
|
||||
return false;
|
||||
}
|
||||
} catch (std::length_error &) {
|
||||
fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
GGML_LOG_ERROR("%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
return false;
|
||||
} catch (std::bad_alloc &) {
|
||||
fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
return false;
|
||||
}
|
||||
kv.emplace_back(key, value);
|
||||
@@ -328,14 +328,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ok = ok && gr.read(magic, 4);
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read magic\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read magic\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < magic.size(); i++) {
|
||||
if (magic[i] != GGUF_MAGIC[i]) {
|
||||
fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
|
||||
GGML_LOG_ERROR("%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -348,11 +348,11 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
if (ok && gr.read(ctx->version)) {
|
||||
if (ctx->version == 1) {
|
||||
fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
|
||||
GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
if (ctx->version > GGUF_VERSION) {
|
||||
fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
|
||||
GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
|
||||
__func__, ctx->version, GGUF_VERSION);
|
||||
ok = false;
|
||||
}
|
||||
@@ -363,7 +363,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
if (ok && gr.read(n_tensors)) {
|
||||
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
||||
if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
|
||||
fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
__func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
|
||||
ok = false;
|
||||
}
|
||||
@@ -374,7 +374,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
if (ok && gr.read(n_kv)) {
|
||||
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
||||
if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
|
||||
fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
__func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
|
||||
ok = false;
|
||||
}
|
||||
@@ -383,7 +383,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read header\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read header\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -399,15 +399,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
try {
|
||||
ok = ok && gr.read(key);
|
||||
} catch (std::length_error &) {
|
||||
fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
} catch (std::bad_alloc &) {
|
||||
fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
}
|
||||
for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
|
||||
if (key == ctx->kv[j].key) {
|
||||
fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
|
||||
GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
|
||||
ok = false;
|
||||
}
|
||||
}
|
||||
@@ -441,14 +441,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
case GGUF_TYPE_ARRAY:
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
|
||||
GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
|
||||
ok = false;
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -458,7 +458,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
|
||||
|
||||
if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
|
||||
fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
|
||||
GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -474,14 +474,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
try {
|
||||
ok = ok && gr.read(name);
|
||||
} catch (std::length_error &) {
|
||||
fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
} catch (std::bad_alloc &) {
|
||||
fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
}
|
||||
if (name.length() >= GGML_MAX_NAME) {
|
||||
fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
|
||||
GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -490,7 +490,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
// make sure there are no duplicate tensor names
|
||||
for (int64_t j = 0; ok && j < i; ++j) {
|
||||
if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
|
||||
fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
|
||||
GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -505,7 +505,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
uint32_t n_dims = -1;
|
||||
ok = ok && gr.read(n_dims);
|
||||
if (n_dims > GGML_MAX_DIMS) {
|
||||
fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
|
||||
ok = false;
|
||||
break;
|
||||
@@ -518,7 +518,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// check that all ne are non-negative
|
||||
if (info.t.ne[j] < 0) {
|
||||
fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
||||
__func__, info.t.name, j, info.t.ne[j]);
|
||||
ok = false;
|
||||
break;
|
||||
@@ -530,7 +530,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
|
||||
(INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
|
||||
|
||||
fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
|
||||
GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape "
|
||||
"(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
|
||||
__func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
|
||||
ok = false;
|
||||
@@ -547,7 +547,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// check that tensor type is within defined range
|
||||
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
|
||||
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
|
||||
ok = false;
|
||||
break;
|
||||
@@ -557,7 +557,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// check that row size is divisible by block size
|
||||
if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
|
||||
fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
|
||||
GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
|
||||
"not a multiple of block size (%" PRId64 ")\n",
|
||||
__func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
|
||||
ok = false;
|
||||
@@ -582,7 +582,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read tensor info\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -590,7 +590,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// we require the data section to be aligned, so take into account any padding
|
||||
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
|
||||
fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -604,9 +604,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
||||
const gguf_tensor_info & ti = ctx->info[i];
|
||||
if (ti.offset != ctx->size) {
|
||||
fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||
__func__, ti.t.name, ti.offset, ctx->size);
|
||||
fprintf(stderr, "%s: failed to read tensor data\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -634,7 +634,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
*params.ctx = ggml_init(pdata);
|
||||
if (*params.ctx == nullptr) {
|
||||
fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -656,7 +656,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ok = ok && gr.read(data->data, ctx->size);
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__);
|
||||
ggml_free(ctx_data);
|
||||
*params.ctx = nullptr;
|
||||
gguf_free(ctx);
|
||||
@@ -689,7 +689,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to create tensors\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to create tensors\n", __func__);
|
||||
ggml_free(ctx_data);
|
||||
*params.ctx = nullptr;
|
||||
gguf_free(ctx);
|
||||
@@ -706,7 +706,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
FILE * file = ggml_fopen(fname, "rb");
|
||||
|
||||
if (!file) {
|
||||
fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
|
||||
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -1305,7 +1305,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
|
||||
FILE * file = ggml_fopen(fname, "wb");
|
||||
|
||||
if (!file) {
|
||||
fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
|
||||
GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -482,14 +482,15 @@ class MODEL_TENSOR(IntEnum):
|
||||
V_ENC_EMBD_CLS = auto()
|
||||
V_ENC_EMBD_PATCH = auto()
|
||||
V_ENC_EMBD_POS = auto()
|
||||
V_ENC_INPUT_NORM = auto()
|
||||
V_ENC_ATTN_Q = auto()
|
||||
V_ENC_ATTN_Q_NORM = auto()
|
||||
V_ENC_ATTN_K = auto()
|
||||
V_ENC_ATTN_K_NORM = auto()
|
||||
V_ENC_ATTN_V = auto()
|
||||
V_ENC_INPUT_NORM = auto()
|
||||
V_ENC_OUTPUT = auto()
|
||||
V_ENC_OUTPUT_NORM = auto()
|
||||
V_ENC_ATTN_O = auto()
|
||||
V_ENC_ATTN_O_NORM = auto()
|
||||
V_ENC_POST_ATTN_NORM = auto()
|
||||
V_ENC_FFN_UP = auto()
|
||||
V_ENC_FFN_GATE = auto()
|
||||
V_ENC_FFN_DOWN = auto()
|
||||
@@ -749,8 +750,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
|
||||
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
|
||||
MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
|
||||
@@ -785,14 +787,15 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS,
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH,
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_V,
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_OUTPUT,
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_O,
|
||||
MODEL_TENSOR.V_ENC_ATTN_O_NORM,
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
|
||||
MODEL_TENSOR.V_ENC_FFN_UP,
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE,
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN,
|
||||
@@ -1905,6 +1908,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.CHAMELEON: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
@@ -2177,6 +2183,7 @@ class VisionProjectorType:
|
||||
GEMMA3 = "gemma3"
|
||||
IDEFICS3 = "idefics3"
|
||||
PIXTRAL = "pixtral"
|
||||
LLAMA4 = "llama4"
|
||||
QWEN2VL = "qwen2vl_merger"
|
||||
QWEN25VL = "qwen2.5vl_merger"
|
||||
INTERNVL = "internvl"
|
||||
|
||||
@@ -823,6 +823,7 @@ class GGUFEditorWindow(QMainWindow):
|
||||
self.modified = False
|
||||
self.metadata_changes = {} # Store changes to apply when saving
|
||||
self.metadata_to_remove = set() # Store keys to remove when saving
|
||||
self.on_metadata_changed_is_connected = False
|
||||
|
||||
self.setup_ui()
|
||||
|
||||
@@ -941,9 +942,11 @@ class GGUFEditorWindow(QMainWindow):
|
||||
return
|
||||
|
||||
# Disconnect to prevent triggering during loading
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
|
||||
if self.on_metadata_changed_is_connected:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
|
||||
self.on_metadata_changed_is_connected = False
|
||||
|
||||
for i, (key, field) in enumerate(self.reader.fields.items()):
|
||||
self.metadata_table.insertRow(i)
|
||||
@@ -1021,6 +1024,7 @@ class GGUFEditorWindow(QMainWindow):
|
||||
|
||||
# Reconnect after loading
|
||||
self.metadata_table.itemChanged.connect(self.on_metadata_changed)
|
||||
self.on_metadata_changed_is_connected = True
|
||||
|
||||
def extract_array_values(self, field: ReaderField) -> list:
|
||||
"""Extract all values from an array field."""
|
||||
|
||||
@@ -68,7 +68,7 @@ class TensorNameMap:
|
||||
"output_layer", # chatglm
|
||||
"head", # rwkv
|
||||
"head.out", # wavtokenizer
|
||||
"language_model.lm_head", # llama4
|
||||
"lm_head", # llama4
|
||||
),
|
||||
|
||||
# Output norm
|
||||
@@ -91,7 +91,7 @@ class TensorNameMap:
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
"language_model.model.norm", # llama4
|
||||
"model.norm", # llama4
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@@ -133,7 +133,7 @@ class TensorNameMap:
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
"language_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"model.layers.{bid}.input_layernorm", # llama4
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@@ -173,7 +173,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@@ -188,7 +188,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@@ -202,7 +202,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wv", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
|
||||
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output
|
||||
@@ -229,7 +229,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@@ -268,7 +268,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -289,7 +289,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"language_model.model.layers.{bid}.feed_forward.router", # llama4
|
||||
"model.layers.{bid}.feed_forward.router", # llama4
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
),
|
||||
|
||||
@@ -329,7 +329,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -338,14 +338,14 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
@@ -366,22 +366,22 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
@@ -410,7 +410,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -420,14 +420,15 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
@@ -901,10 +902,12 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_FC: (
|
||||
"model.connector.modality_projection.proj", # SmolVLM
|
||||
"multi_modal_projector.linear_1", # llama 4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_MLP: (
|
||||
"model.mm_projector.mlp.mlp.{bid}",
|
||||
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
|
||||
"mlp1.{bid}", # InternVL
|
||||
),
|
||||
|
||||
@@ -914,6 +917,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: (
|
||||
"vision_tower.vision_model.embeddings.class_embedding",
|
||||
"vision_model.class_embedding", # llama 4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
|
||||
@@ -921,6 +925,7 @@ class TensorNameMap:
|
||||
"vpm.embeddings.patch_embedding",
|
||||
"model.vision_model.embeddings.patch_embedding", # SmolVLM
|
||||
"vision_tower.patch_conv", # pixtral
|
||||
"vision_model.patch_embedding.linear", # llama 4
|
||||
"visual.patch_embed.proj", # qwen2vl
|
||||
),
|
||||
|
||||
@@ -928,12 +933,14 @@ class TensorNameMap:
|
||||
"vision_tower.vision_model.embeddings.position_embedding",
|
||||
"vpm.embeddings.position_embedding",
|
||||
"model.vision_model.embeddings.position_embedding", # SmolVLM
|
||||
"vision_model.positional_embedding_vlm", # llama 4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.q_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
|
||||
),
|
||||
@@ -946,6 +953,7 @@ class TensorNameMap:
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
|
||||
),
|
||||
@@ -958,6 +966,7 @@ class TensorNameMap:
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
|
||||
),
|
||||
@@ -968,23 +977,26 @@ class TensorNameMap:
|
||||
"vpm.encoder.layers.{bid}.layer_norm1",
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
|
||||
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
|
||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"visual.blocks.{bid}.norm1", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_OUTPUT: (
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
|
||||
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.proj", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
|
||||
"vpm.encoder.layers.{bid}.layer_norm2",
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
|
||||
"visual.blocks.{bid}.norm2", # qwen2vl
|
||||
),
|
||||
@@ -994,6 +1006,7 @@ class TensorNameMap:
|
||||
"vpm.encoder.layers.{bid}.mlp.fc1",
|
||||
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
|
||||
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
|
||||
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
||||
),
|
||||
@@ -1008,6 +1021,7 @@ class TensorNameMap:
|
||||
"vpm.encoder.layers.{bid}.mlp.fc2",
|
||||
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
|
||||
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
|
||||
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
||||
),
|
||||
@@ -1023,11 +1037,13 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.V_PRE_NORM: (
|
||||
"vision_tower.vision_model.pre_layrnorm",
|
||||
"vision_tower.ln_pre", # pixtral
|
||||
"vision_model.layernorm_pre", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_POST_NORM: (
|
||||
"vision_tower.vision_model.post_layernorm",
|
||||
"model.vision_model.post_layernorm", # SmolVLM
|
||||
"vision_model.layernorm_post", # llama4
|
||||
"visual.merger.ln_q", # qwen2vl
|
||||
),
|
||||
|
||||
|
||||
180
include/llama.h
180
include/llama.h
@@ -4,6 +4,7 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@@ -344,7 +345,7 @@ extern "C" {
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default)
|
||||
float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
@@ -360,10 +361,11 @@ extern "C" {
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool op_offload; // whether to offload host tensor operations to device
|
||||
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // measure performance timings
|
||||
bool op_offload; // offload host tensor operations to device
|
||||
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@@ -445,6 +447,10 @@ extern "C" {
|
||||
size_t n_paths,
|
||||
struct llama_model_params params);
|
||||
|
||||
LLAMA_API void llama_model_save_to_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_model);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
|
||||
"use llama_model_free instead");
|
||||
|
||||
@@ -602,72 +608,13 @@ extern "C" {
|
||||
// KV cache
|
||||
//
|
||||
|
||||
// TODO: start using struct llama_kv_cache
|
||||
|
||||
// Information associated with an individual cell in the KV cache view.
|
||||
struct llama_kv_cache_view_cell {
|
||||
// The position for this cell. Takes KV cache shifts into account.
|
||||
// May be negative if the cell is not populated.
|
||||
llama_pos pos;
|
||||
};
|
||||
|
||||
// An updateable view of the KV cache.
|
||||
struct llama_kv_cache_view {
|
||||
// Number of KV cache cells. This will be the same as the context size.
|
||||
int32_t n_cells;
|
||||
|
||||
// Maximum number of sequences that can exist in a cell. It's not an error
|
||||
// if there are more sequences in a cell than this value, however they will
|
||||
// not be visible in the view cells_sequences.
|
||||
int32_t n_seq_max;
|
||||
|
||||
// Number of tokens in the cache. For example, if there are two populated
|
||||
// cells, the first with 1 sequence id in it and the second with 2 sequence
|
||||
// ids then you'll have 3 tokens.
|
||||
int32_t token_count;
|
||||
|
||||
// Number of populated cache cells.
|
||||
int32_t used_cells;
|
||||
|
||||
// Maximum contiguous empty slots in the cache.
|
||||
int32_t max_contiguous;
|
||||
|
||||
// Index to the start of the max_contiguous slot range. Can be negative
|
||||
// when cache is full.
|
||||
int32_t max_contiguous_idx;
|
||||
|
||||
// Information for an individual cell.
|
||||
struct llama_kv_cache_view_cell * cells;
|
||||
|
||||
// The sequences for each cell. There will be n_seq_max items per cell.
|
||||
llama_seq_id * cells_sequences;
|
||||
};
|
||||
|
||||
// Create an empty KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
|
||||
|
||||
// Free a KV cache view. (use only for debugging purposes)
|
||||
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
||||
|
||||
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
|
||||
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||
|
||||
///
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||
"use llama_kv_self_n_tokens instead");
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
|
||||
"use llama_kv_self_used_cells instead");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx);
|
||||
@@ -725,10 +672,18 @@ extern "C" {
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
// Returns the smallest position present in the KV cache for the specified sequence
|
||||
// This is typically non-zero only for SWA caches
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
// Return -1 if the sequence is empty
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
@@ -742,61 +697,6 @@ extern "C" {
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
|
||||
struct llama_context * ctx),
|
||||
"use llama_kv_self_clear instead");
|
||||
|
||||
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_kv_self_seq_rm instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_kv_self_seq_cp instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_kv_self_seq_keep instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta),
|
||||
"use llama_kv_self_seq_add instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d),
|
||||
"use llama_kv_self_seq_div instead");
|
||||
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_kv_self_seq_pos_max instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
|
||||
"use llama_kv_self_defrag instead");
|
||||
|
||||
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
|
||||
"use llama_kv_self_can_shift instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
|
||||
"use llama_kv_self_update instead");
|
||||
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
//
|
||||
@@ -938,9 +838,12 @@ extern "C" {
|
||||
// Requires KV cache.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
// Upon non-zero return values, the KV cache state is restored to the state before this call
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// 2 - aborted
|
||||
// -1 - invalid input batch
|
||||
// < -1 - error
|
||||
LLAMA_API int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
@@ -1433,6 +1336,37 @@ extern "C" {
|
||||
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
|
||||
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
// function that returns whether or not a given tensor contains trainable parameters
|
||||
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// always returns true
|
||||
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
struct llama_opt_params {
|
||||
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
|
||||
|
||||
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
|
||||
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
};
|
||||
|
||||
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
LLAMA_API void llama_opt_epoch(
|
||||
struct llama_context * lctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -7,6 +7,10 @@ import sys
|
||||
import os
|
||||
from glob import glob
|
||||
import sqlite3
|
||||
import json
|
||||
import csv
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
try:
|
||||
import git
|
||||
@@ -17,6 +21,28 @@ except ImportError as e:
|
||||
|
||||
logger = logging.getLogger("compare-llama-bench")
|
||||
|
||||
# All llama-bench SQL fields
|
||||
DB_FIELDS = [
|
||||
"build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
|
||||
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
|
||||
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
|
||||
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"defrag_thold",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
|
||||
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
|
||||
]
|
||||
|
||||
DB_TYPES = [
|
||||
"TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT",
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
|
||||
"REAL",
|
||||
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
|
||||
]
|
||||
assert len(DB_FIELDS) == len(DB_TYPES)
|
||||
|
||||
# Properties by which to differentiate results per commit:
|
||||
KEY_PROPERTIES = [
|
||||
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
|
||||
@@ -42,7 +68,7 @@ DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default.
|
||||
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
|
||||
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
|
||||
|
||||
DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux):
|
||||
DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
|
||||
|
||||
$ git checkout master
|
||||
$ make clean && make llama-bench
|
||||
@@ -70,12 +96,13 @@ help_c = (
|
||||
)
|
||||
parser.add_argument("-c", "--compare", help=help_c)
|
||||
help_i = (
|
||||
"Input SQLite file for comparing commits. "
|
||||
"JSON/JSONL/SQLite/CSV files for comparing commits. "
|
||||
"Specify multiple times to use multiple input files (JSON/CSV only). "
|
||||
"Defaults to 'llama-bench.sqlite' in the current working directory. "
|
||||
"If no such file is found and there is exactly one .sqlite file in the current directory, "
|
||||
"that file is instead used as input."
|
||||
)
|
||||
parser.add_argument("-i", "--input", help=help_i)
|
||||
parser.add_argument("-i", "--input", action="append", help=help_i)
|
||||
help_o = (
|
||||
"Output format for the table. "
|
||||
"Defaults to 'pipe' (GitHub compatible). "
|
||||
@@ -86,7 +113,7 @@ parser.add_argument("-o", "--output", help=help_o, default="pipe")
|
||||
help_s = (
|
||||
"Columns to add to the table. "
|
||||
"Accepts a comma-separated list of values. "
|
||||
f"Legal values: {', '.join(KEY_PROPERTIES[:-2])}. "
|
||||
f"Legal values: {', '.join(KEY_PROPERTIES[:-3])}. "
|
||||
"Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
|
||||
"plus any column where not all data points are the same. "
|
||||
"If the columns are manually specified, then the results for each unique combination of the "
|
||||
@@ -110,119 +137,321 @@ if unknown_args:
|
||||
sys.exit(1)
|
||||
|
||||
input_file = known_args.input
|
||||
if input_file is None and os.path.exists("./llama-bench.sqlite"):
|
||||
input_file = "llama-bench.sqlite"
|
||||
if input_file is None:
|
||||
if not input_file and os.path.exists("./llama-bench.sqlite"):
|
||||
input_file = ["llama-bench.sqlite"]
|
||||
if not input_file:
|
||||
sqlite_files = glob("*.sqlite")
|
||||
if len(sqlite_files) == 1:
|
||||
input_file = sqlite_files[0]
|
||||
input_file = sqlite_files
|
||||
|
||||
if input_file is None:
|
||||
if not input_file:
|
||||
logger.error("Cannot find a suitable input file, please provide one.\n")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
connection = sqlite3.connect(input_file)
|
||||
cursor = connection.cursor()
|
||||
|
||||
build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
|
||||
build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
|
||||
class LlamaBenchData:
|
||||
repo: Optional[git.Repo]
|
||||
build_len_min: int
|
||||
build_len_max: int
|
||||
build_len: int = 8
|
||||
builds: list[str] = []
|
||||
check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
|
||||
|
||||
if build_len_min != build_len_max:
|
||||
logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
|
||||
"Try purging the the database of old commits.")
|
||||
cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});")
|
||||
def __init__(self):
|
||||
try:
|
||||
self.repo = git.Repo(".", search_parent_directories=True)
|
||||
except git.InvalidGitRepositoryError:
|
||||
self.repo = None
|
||||
|
||||
build_len: int = build_len_min
|
||||
def _builds_init(self):
|
||||
self.build_len = self.build_len_min
|
||||
|
||||
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
|
||||
builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
|
||||
|
||||
if not builds:
|
||||
raise RuntimeError(f"{input_file} does not contain any builds.")
|
||||
|
||||
try:
|
||||
repo = git.Repo(".", search_parent_directories=True)
|
||||
except git.InvalidGitRepositoryError:
|
||||
repo = None
|
||||
|
||||
|
||||
def find_parent_in_data(commit: git.Commit):
|
||||
"""Helper function to find the most recent parent measured in number of commits for which there is data."""
|
||||
heap: list[tuple[int, git.Commit]] = [(0, commit)]
|
||||
seen_hexsha8 = set()
|
||||
while heap:
|
||||
depth, current_commit = heapq.heappop(heap)
|
||||
current_hexsha8 = commit.hexsha[:build_len]
|
||||
if current_hexsha8 in builds:
|
||||
return current_hexsha8
|
||||
for parent in commit.parents:
|
||||
parent_hexsha8 = parent.hexsha[:build_len]
|
||||
if parent_hexsha8 not in seen_hexsha8:
|
||||
seen_hexsha8.add(parent_hexsha8)
|
||||
heapq.heappush(heap, (depth + 1, parent))
|
||||
return None
|
||||
|
||||
|
||||
def get_all_parent_hexsha8s(commit: git.Commit):
|
||||
"""Helper function to recursively get hexsha8 values for all parents of a commit."""
|
||||
unvisited = [commit]
|
||||
visited = []
|
||||
|
||||
while unvisited:
|
||||
current_commit = unvisited.pop(0)
|
||||
visited.append(current_commit.hexsha[:build_len])
|
||||
for parent in current_commit.parents:
|
||||
if parent.hexsha[:build_len] not in visited:
|
||||
unvisited.append(parent)
|
||||
|
||||
return visited
|
||||
|
||||
|
||||
def get_commit_name(hexsha8: str):
|
||||
"""Helper function to find a human-readable name for a commit if possible."""
|
||||
if repo is None:
|
||||
return hexsha8
|
||||
for h in repo.heads:
|
||||
if h.commit.hexsha[:build_len] == hexsha8:
|
||||
return h.name
|
||||
for t in repo.tags:
|
||||
if t.commit.hexsha[:build_len] == hexsha8:
|
||||
return t.name
|
||||
return hexsha8
|
||||
|
||||
|
||||
def get_commit_hexsha8(name: str):
|
||||
"""Helper function to search for a commit given a human-readable name."""
|
||||
if repo is None:
|
||||
def _check_keys(self, keys: set) -> Optional[set]:
|
||||
"""Private helper method that checks against required data keys and returns missing ones."""
|
||||
if not keys >= self.check_keys:
|
||||
return self.check_keys - keys
|
||||
return None
|
||||
for h in repo.heads:
|
||||
if h.name == name:
|
||||
return h.commit.hexsha[:build_len]
|
||||
for t in repo.tags:
|
||||
if t.name == name:
|
||||
return t.commit.hexsha[:build_len]
|
||||
for c in repo.iter_commits("--all"):
|
||||
if c.hexsha[:build_len] == name[:build_len]:
|
||||
return c.hexsha[:build_len]
|
||||
return None
|
||||
|
||||
def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
|
||||
"""Helper method to find the most recent parent measured in number of commits for which there is data."""
|
||||
heap: list[tuple[int, git.Commit]] = [(0, commit)]
|
||||
seen_hexsha8 = set()
|
||||
while heap:
|
||||
depth, current_commit = heapq.heappop(heap)
|
||||
current_hexsha8 = commit.hexsha[:self.build_len]
|
||||
if current_hexsha8 in self.builds:
|
||||
return current_hexsha8
|
||||
for parent in commit.parents:
|
||||
parent_hexsha8 = parent.hexsha[:self.build_len]
|
||||
if parent_hexsha8 not in seen_hexsha8:
|
||||
seen_hexsha8.add(parent_hexsha8)
|
||||
heapq.heappush(heap, (depth + 1, parent))
|
||||
return None
|
||||
|
||||
def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]:
|
||||
"""Helper method to recursively get hexsha8 values for all parents of a commit."""
|
||||
unvisited = [commit]
|
||||
visited = []
|
||||
|
||||
while unvisited:
|
||||
current_commit = unvisited.pop(0)
|
||||
visited.append(current_commit.hexsha[:self.build_len])
|
||||
for parent in current_commit.parents:
|
||||
if parent.hexsha[:self.build_len] not in visited:
|
||||
unvisited.append(parent)
|
||||
|
||||
return visited
|
||||
|
||||
def get_commit_name(self, hexsha8: str) -> str:
|
||||
"""Helper method to find a human-readable name for a commit if possible."""
|
||||
if self.repo is None:
|
||||
return hexsha8
|
||||
for h in self.repo.heads:
|
||||
if h.commit.hexsha[:self.build_len] == hexsha8:
|
||||
return h.name
|
||||
for t in self.repo.tags:
|
||||
if t.commit.hexsha[:self.build_len] == hexsha8:
|
||||
return t.name
|
||||
return hexsha8
|
||||
|
||||
def get_commit_hexsha8(self, name: str) -> Optional[str]:
|
||||
"""Helper method to search for a commit given a human-readable name."""
|
||||
if self.repo is None:
|
||||
return None
|
||||
for h in self.repo.heads:
|
||||
if h.name == name:
|
||||
return h.commit.hexsha[:self.build_len]
|
||||
for t in self.repo.tags:
|
||||
if t.name == name:
|
||||
return t.commit.hexsha[:self.build_len]
|
||||
for c in self.repo.iter_commits("--all"):
|
||||
if c.hexsha[:self.build_len] == name[:self.build_len]:
|
||||
return c.hexsha[:self.build_len]
|
||||
return None
|
||||
|
||||
def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
|
||||
"""Helper method that gets rows of (build_commit, test_time) sorted by the latter."""
|
||||
return []
|
||||
|
||||
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
|
||||
"""
|
||||
Helper method that gets table rows for some list of properties.
|
||||
Rows are created by combining those where all provided properties are equal.
|
||||
The resulting rows are then grouped by the provided properties and the t/s values are averaged.
|
||||
The returned rows are unique in terms of property combinations.
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class LlamaBenchDataSQLite3(LlamaBenchData):
|
||||
connection: sqlite3.Connection
|
||||
cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.connection = sqlite3.connect(":memory:")
|
||||
self.cursor = self.connection.cursor()
|
||||
self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});")
|
||||
|
||||
def _builds_init(self):
|
||||
if self.connection:
|
||||
self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
|
||||
self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
|
||||
|
||||
if self.build_len_min != self.build_len_max:
|
||||
logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
|
||||
"Try purging the the database of old commits.")
|
||||
self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
|
||||
|
||||
builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
|
||||
self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
|
||||
super()._builds_init()
|
||||
|
||||
def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
|
||||
data = self.cursor.execute(
|
||||
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
|
||||
return reversed(data) if reverse else data
|
||||
|
||||
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
|
||||
select_string = ", ".join(
|
||||
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
|
||||
equal_string = " AND ".join(
|
||||
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
|
||||
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
|
||||
)
|
||||
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
|
||||
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
|
||||
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
|
||||
return self.cursor.execute(query).fetchall()
|
||||
|
||||
|
||||
class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
|
||||
def __init__(self, data_file: str):
|
||||
super().__init__()
|
||||
|
||||
self.connection.close()
|
||||
self.connection = sqlite3.connect(data_file)
|
||||
self.cursor = self.connection.cursor()
|
||||
self._builds_init()
|
||||
|
||||
@staticmethod
|
||||
def valid_format(data_file: str) -> bool:
|
||||
connection = sqlite3.connect(data_file)
|
||||
cursor = connection.cursor()
|
||||
|
||||
try:
|
||||
if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
|
||||
raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
|
||||
except sqlite3.DatabaseError as e:
|
||||
logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
|
||||
cursor = None
|
||||
|
||||
connection.close()
|
||||
return True if cursor else False
|
||||
|
||||
|
||||
class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
|
||||
def __init__(self, data_file: str):
|
||||
super().__init__()
|
||||
|
||||
with open(data_file, "r", encoding="utf-8") as fp:
|
||||
for i, line in enumerate(fp):
|
||||
parsed = json.loads(line)
|
||||
|
||||
for k in parsed.keys() - set(DB_FIELDS):
|
||||
del parsed[k]
|
||||
|
||||
if (missing_keys := self._check_keys(parsed.keys())):
|
||||
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
|
||||
|
||||
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
|
||||
|
||||
self._builds_init()
|
||||
|
||||
@staticmethod
|
||||
def valid_format(data_file: str) -> bool:
|
||||
try:
|
||||
with open(data_file, "r", encoding="utf-8") as fp:
|
||||
for line in fp:
|
||||
json.loads(line)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
|
||||
def __init__(self, data_files: list[str]):
|
||||
super().__init__()
|
||||
|
||||
for data_file in data_files:
|
||||
with open(data_file, "r", encoding="utf-8") as fp:
|
||||
parsed = json.load(fp)
|
||||
|
||||
for i, entry in enumerate(parsed):
|
||||
for k in entry.keys() - set(DB_FIELDS):
|
||||
del entry[k]
|
||||
|
||||
if (missing_keys := self._check_keys(entry.keys())):
|
||||
raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
|
||||
|
||||
self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
|
||||
|
||||
self._builds_init()
|
||||
|
||||
@staticmethod
|
||||
def valid_format(data_files: list[str]) -> bool:
|
||||
if not data_files:
|
||||
return False
|
||||
|
||||
for data_file in data_files:
|
||||
try:
|
||||
with open(data_file, "r", encoding="utf-8") as fp:
|
||||
json.load(fp)
|
||||
except Exception as e:
|
||||
logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
|
||||
def __init__(self, data_files: list[str]):
|
||||
super().__init__()
|
||||
|
||||
for data_file in data_files:
|
||||
with open(data_file, "r", encoding="utf-8") as fp:
|
||||
for i, parsed in enumerate(csv.DictReader(fp)):
|
||||
keys = set(parsed.keys())
|
||||
|
||||
for k in keys - set(DB_FIELDS):
|
||||
del parsed[k]
|
||||
|
||||
if (missing_keys := self._check_keys(keys)):
|
||||
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
|
||||
|
||||
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
|
||||
|
||||
self._builds_init()
|
||||
|
||||
@staticmethod
|
||||
def valid_format(data_files: list[str]) -> bool:
|
||||
if not data_files:
|
||||
return False
|
||||
|
||||
for data_file in data_files:
|
||||
try:
|
||||
with open(data_file, "r", encoding="utf-8") as fp:
|
||||
for parsed in csv.DictReader(fp):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
bench_data = None
|
||||
if len(input_file) == 1:
|
||||
if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
|
||||
bench_data = LlamaBenchDataSQLite3File(input_file[0])
|
||||
elif LlamaBenchDataJSON.valid_format(input_file):
|
||||
bench_data = LlamaBenchDataJSON(input_file)
|
||||
elif LlamaBenchDataJSONL.valid_format(input_file[0]):
|
||||
bench_data = LlamaBenchDataJSONL(input_file[0])
|
||||
elif LlamaBenchDataCSV.valid_format(input_file):
|
||||
bench_data = LlamaBenchDataCSV(input_file)
|
||||
else:
|
||||
if LlamaBenchDataJSON.valid_format(input_file):
|
||||
bench_data = LlamaBenchDataJSON(input_file)
|
||||
elif LlamaBenchDataCSV.valid_format(input_file):
|
||||
bench_data = LlamaBenchDataCSV(input_file)
|
||||
|
||||
if not bench_data:
|
||||
raise RuntimeError("No valid (or some invalid) input files found.")
|
||||
|
||||
if not bench_data.builds:
|
||||
raise RuntimeError(f"{input_file} does not contain any builds.")
|
||||
|
||||
|
||||
hexsha8_baseline = name_baseline = None
|
||||
|
||||
# If the user specified a baseline, try to find a commit for it:
|
||||
if known_args.baseline is not None:
|
||||
if known_args.baseline in builds:
|
||||
if known_args.baseline in bench_data.builds:
|
||||
hexsha8_baseline = known_args.baseline
|
||||
if hexsha8_baseline is None:
|
||||
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
|
||||
hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline)
|
||||
name_baseline = known_args.baseline
|
||||
if hexsha8_baseline is None:
|
||||
logger.error(f"cannot find data for baseline={known_args.baseline}.")
|
||||
sys.exit(1)
|
||||
# Otherwise, search for the most recent parent of master for which there is data:
|
||||
elif repo is not None:
|
||||
hexsha8_baseline = find_parent_in_data(repo.heads.master.commit)
|
||||
elif bench_data.repo is not None:
|
||||
hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit)
|
||||
|
||||
if hexsha8_baseline is None:
|
||||
logger.error("No baseline was provided and did not find data for any master branch commits.\n")
|
||||
@@ -235,27 +464,25 @@ else:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
name_baseline = get_commit_name(hexsha8_baseline)
|
||||
name_baseline = bench_data.get_commit_name(hexsha8_baseline)
|
||||
|
||||
hexsha8_compare = name_compare = None
|
||||
|
||||
# If the user has specified a compare value, try to find a corresponding commit:
|
||||
if known_args.compare is not None:
|
||||
if known_args.compare in builds:
|
||||
if known_args.compare in bench_data.builds:
|
||||
hexsha8_compare = known_args.compare
|
||||
if hexsha8_compare is None:
|
||||
hexsha8_compare = get_commit_hexsha8(known_args.compare)
|
||||
hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare)
|
||||
name_compare = known_args.compare
|
||||
if hexsha8_compare is None:
|
||||
logger.error(f"cannot find data for compare={known_args.compare}.")
|
||||
sys.exit(1)
|
||||
# Otherwise, search for the commit for llama-bench was most recently run
|
||||
# and that is not a parent of master:
|
||||
elif repo is not None:
|
||||
hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit)
|
||||
builds_timestamp = cursor.execute(
|
||||
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
|
||||
for (hexsha8, _) in reversed(builds_timestamp):
|
||||
elif bench_data.repo is not None:
|
||||
hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit)
|
||||
for (hexsha8, _) in bench_data.builds_timestamp(reverse=True):
|
||||
if hexsha8 not in hexsha8s_master:
|
||||
hexsha8_compare = hexsha8
|
||||
break
|
||||
@@ -270,26 +497,7 @@ else:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
name_compare = get_commit_name(hexsha8_compare)
|
||||
|
||||
|
||||
def get_rows(properties):
|
||||
"""
|
||||
Helper function that gets table rows for some list of properties.
|
||||
Rows are created by combining those where all provided properties are equal.
|
||||
The resulting rows are then grouped by the provided properties and the t/s values are averaged.
|
||||
The returned rows are unique in terms of property combinations.
|
||||
"""
|
||||
select_string = ", ".join(
|
||||
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
|
||||
equal_string = " AND ".join(
|
||||
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
|
||||
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
|
||||
)
|
||||
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
|
||||
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
|
||||
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
|
||||
return cursor.execute(query).fetchall()
|
||||
name_compare = bench_data.get_commit_name(hexsha8_compare)
|
||||
|
||||
|
||||
# If the user provided columns to group the results by, use them:
|
||||
@@ -297,16 +505,16 @@ if known_args.show is not None:
|
||||
show = known_args.show.split(",")
|
||||
unknown_cols = []
|
||||
for prop in show:
|
||||
if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen.
|
||||
if prop not in KEY_PROPERTIES[:-3]: # Last three values are n_prompt, n_gen, n_depth.
|
||||
unknown_cols.append(prop)
|
||||
if unknown_cols:
|
||||
logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
|
||||
parser.print_usage()
|
||||
sys.exit(1)
|
||||
rows_show = get_rows(show)
|
||||
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
|
||||
# Otherwise, select those columns where the values are not all the same:
|
||||
else:
|
||||
rows_full = get_rows(KEY_PROPERTIES)
|
||||
rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare)
|
||||
properties_different = []
|
||||
for i, kp_i in enumerate(KEY_PROPERTIES):
|
||||
if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]:
|
||||
@@ -336,7 +544,7 @@ else:
|
||||
show.remove(prop)
|
||||
except ValueError:
|
||||
pass
|
||||
rows_show = get_rows(show)
|
||||
rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
|
||||
|
||||
if not rows_show:
|
||||
logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n")
|
||||
|
||||
@@ -1 +1 @@
|
||||
b59bddafe278877dfa22a80e53a637513862babb
|
||||
7c06c10c532a6cda913c17fc56341e8880ae341d
|
||||
|
||||
@@ -23,6 +23,7 @@ add_library(llama
|
||||
llama-memory.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
llama-model-saver.cpp
|
||||
llama-model.cpp
|
||||
llama-quant.cpp
|
||||
llama-sampling.cpp
|
||||
|
||||
@@ -1481,6 +1481,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -93,6 +93,7 @@ llama_context::llama_context(
|
||||
}
|
||||
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.op_offload = params.op_offload;
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
@@ -176,8 +177,9 @@ llama_context::llama_context(
|
||||
// init the memory module
|
||||
if (!hparams.vocab_only) {
|
||||
llama_memory_params params_mem = {
|
||||
/*.type_k =*/ params.type_k,
|
||||
/*.type_v =*/ params.type_v,
|
||||
/*.type_k =*/ params.type_k,
|
||||
/*.type_v =*/ params.type_v,
|
||||
/*.swa_full =*/ params.swa_full,
|
||||
};
|
||||
|
||||
memory.reset(model.create_memory(params_mem, cparams));
|
||||
@@ -359,7 +361,9 @@ llama_context::llama_context(
|
||||
}
|
||||
}
|
||||
|
||||
llama_context::~llama_context() = default;
|
||||
llama_context::~llama_context() {
|
||||
ggml_opt_free(opt_ctx);
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
ggml_backend_sched_synchronize(sched.get());
|
||||
@@ -945,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
// find KV slot
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -1702,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_write(io);
|
||||
if (kv_self != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
kv_self->state_write(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
@@ -1788,10 +1792,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
||||
|
||||
kv_self->state_read(io);
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_read(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
@@ -1799,9 +1806,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
if (memory) {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_write(io, seq_id);
|
||||
kv_self->state_write(io, seq_id);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
@@ -1809,9 +1818,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
||||
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
if (memory) {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_read(io, seq_id);
|
||||
kv_self->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
@@ -1839,6 +1850,215 @@ void llama_context::perf_reset() {
|
||||
t_p_eval_us = n_p_eval = 0;
|
||||
}
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
|
||||
if (!tensor || tensor->type != GGML_TYPE_F32) {
|
||||
return;
|
||||
}
|
||||
if (!param_filter(tensor, userdata)) {
|
||||
return;
|
||||
}
|
||||
if (strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
return; // FIXME
|
||||
}
|
||||
if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
|
||||
return; // FIXME
|
||||
}
|
||||
ggml_set_param(tensor);
|
||||
}
|
||||
|
||||
void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
|
||||
GGML_ASSERT(!opt_ctx);
|
||||
model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
|
||||
const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
|
||||
GGML_ASSERT(n_batch % n_ubatch == 0);
|
||||
|
||||
ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
|
||||
opt_params.opt_period = n_batch / n_ubatch;
|
||||
opt_params.get_opt_pars = lopt_params.get_opt_pars;
|
||||
opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
|
||||
|
||||
opt_ctx = ggml_opt_init(opt_params);
|
||||
|
||||
llama_opt_param_filter param_filter = lopt_params.param_filter;
|
||||
void * param_filter_ud = lopt_params.param_filter_ud;
|
||||
|
||||
//llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
|
||||
llama_set_param(model->type_embd, param_filter, param_filter_ud);
|
||||
llama_set_param(model->pos_embd, param_filter, param_filter_ud);
|
||||
llama_set_param(model->tok_norm, param_filter, param_filter_ud);
|
||||
llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_b, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_out, param_filter, param_filter_ud);
|
||||
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
|
||||
|
||||
for (struct llama_layer & layer : model->layers) {
|
||||
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
|
||||
llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start) {
|
||||
GGML_ASSERT(opt_ctx);
|
||||
const uint32_t n_ctx = llama_model_n_ctx_train(&model);
|
||||
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
llama_kv_cache_guard kv_guard(kv_self);
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
|
||||
batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
|
||||
batch.pos [pos_batch] = pos_ctx + pos_batch;
|
||||
batch.n_seq_id[pos_batch] = 1;
|
||||
batch.seq_id [pos_batch][0] = 0;
|
||||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
const auto n_tokens_all = batch.n_tokens;
|
||||
|
||||
n_queued_tokens += n_tokens_all;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
embd_seq.clear();
|
||||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
};
|
||||
|
||||
for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
|
||||
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
GGML_ABORT("TODO: handle this error");
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
const size_t size_gf = ggml_graph_size(gf);
|
||||
const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_meta,
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ctx_compute_opt = ggml_init(params);
|
||||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
res->set_inputs(&ubatch);
|
||||
{
|
||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||
GGML_ASSERT(labels->ne[1] == n_ubatch);
|
||||
ggml_set_zero(labels);
|
||||
const float onef = 1.0f;
|
||||
for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
|
||||
const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
|
||||
GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
|
||||
ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
|
||||
}
|
||||
}
|
||||
ggml_opt_eval(opt_ctx, result);
|
||||
if (callback) {
|
||||
callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
ggml_free(ctx_compute_opt);
|
||||
}
|
||||
}
|
||||
|
||||
kv_guard.commit();
|
||||
}
|
||||
|
||||
void llama_context::opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
const uint32_t n_ctx = this->n_ctx();
|
||||
const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
|
||||
const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
|
||||
const int64_t ndata = ggml_opt_dataset_ndata(dataset);
|
||||
|
||||
GGML_ASSERT(idata_split >= 0);
|
||||
GGML_ASSERT(idata_split <= ndata);
|
||||
|
||||
const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
|
||||
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
std::vector<llama_token> tokens(n_ctx);
|
||||
std::vector<llama_token> labels_sparse(n_ctx);
|
||||
|
||||
int64_t idata = 0;
|
||||
|
||||
int64_t t_loop_start = ggml_time_us();
|
||||
int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
|
||||
for (; idata < idata_split; ++idata) {
|
||||
constexpr bool train = true;
|
||||
const int64_t idata_in_loop = idata*ubatch_per_ctx;
|
||||
|
||||
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
|
||||
opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
|
||||
callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
|
||||
t_loop_start = ggml_time_us();
|
||||
ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
|
||||
for (; idata < ndata; ++idata) {
|
||||
constexpr bool train = false;
|
||||
const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
|
||||
|
||||
ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
|
||||
opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
|
||||
callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
@@ -1873,6 +2093,7 @@ llama_context_params llama_context_default_params() {
|
||||
/*.flash_attn =*/ false,
|
||||
/*.no_perf =*/ true,
|
||||
/*.op_offload =*/ true,
|
||||
/*.swa_full =*/ true,
|
||||
};
|
||||
|
||||
return result;
|
||||
@@ -2067,39 +2288,10 @@ int32_t llama_apply_adapter_cvec(
|
||||
return res ? 0 : -1;
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (kv == nullptr) {
|
||||
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
|
||||
return {};
|
||||
}
|
||||
|
||||
return llama_kv_cache_view_init(*kv, n_seq_max);
|
||||
}
|
||||
|
||||
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (kv == nullptr) {
|
||||
LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_view_update(view, kv);
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
||||
return llama_kv_self_n_tokens(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
@@ -2109,11 +2301,6 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
return kv->get_n_tokens();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_self_used_cells(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
@@ -2123,11 +2310,6 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
return kv->get_used_cells();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_clear(llama_context * ctx) {
|
||||
llama_kv_self_clear(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
@@ -2137,15 +2319,6 @@ void llama_kv_self_clear(llama_context * ctx) {
|
||||
kv->clear();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
@@ -2159,16 +2332,6 @@ bool llama_kv_self_seq_rm(
|
||||
return kv->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
@@ -2183,13 +2346,6 @@ void llama_kv_self_seq_cp(
|
||||
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_keep(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id) {
|
||||
llama_kv_self_seq_keep(ctx, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
@@ -2199,16 +2355,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
kv->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
@@ -2223,16 +2369,6 @@ void llama_kv_self_seq_add(
|
||||
kv->seq_add(seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
@@ -2247,25 +2383,24 @@ void llama_kv_self_seq_div(
|
||||
kv->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
||||
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return kv->seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
return 0;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return kv->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_defrag(llama_context * ctx) {
|
||||
llama_kv_self_defrag(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
@@ -2276,11 +2411,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
|
||||
kv->defrag_sched(-1.0f);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_self_can_shift(ctx);
|
||||
}
|
||||
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
if (!kv) {
|
||||
@@ -2290,11 +2420,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
return kv->get_can_shift();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_update(llama_context * ctx) {
|
||||
llama_kv_self_update(ctx);
|
||||
}
|
||||
|
||||
// llama state API
|
||||
|
||||
// deprecated
|
||||
@@ -2417,7 +2542,21 @@ int32_t llama_encode(
|
||||
int32_t llama_decode(
|
||||
llama_context * ctx,
|
||||
llama_batch batch) {
|
||||
const int ret = ctx->decode(batch);
|
||||
int ret = ctx->decode(batch);
|
||||
|
||||
// defrag and try again
|
||||
// TODO: distinguish return code when we are sure that even after defrag there is no space available
|
||||
if (ret == 1) {
|
||||
llama_kv_self_defrag(ctx);
|
||||
ret = ctx->decode(batch);
|
||||
|
||||
if (ret == 1) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
@@ -2457,3 +2596,34 @@ void llama_perf_context_print(const llama_context * ctx) {
|
||||
void llama_perf_context_reset(llama_context * ctx) {
|
||||
ctx->perf_reset();
|
||||
}
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
|
||||
GGML_UNUSED(tensor);
|
||||
GGML_UNUSED(userdata);
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
|
||||
ctx->opt_init(model, lopt_params);
|
||||
}
|
||||
|
||||
void llama_opt_epoch(
|
||||
struct llama_context * ctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval) {
|
||||
ctx->opt_epoch(
|
||||
dataset,
|
||||
result_train,
|
||||
result_eval,
|
||||
idata_split,
|
||||
callback_train,
|
||||
callback_eval);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "llama-adapter.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
@@ -133,6 +134,32 @@ struct llama_context {
|
||||
llama_perf_context_data perf_get_data() const;
|
||||
void perf_reset();
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
void opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
void opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start);
|
||||
|
||||
private:
|
||||
//
|
||||
// output
|
||||
@@ -212,6 +239,9 @@ private:
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
||||
// training
|
||||
ggml_opt_context_t opt_ctx = nullptr;
|
||||
|
||||
ggml_threadpool_t threadpool = nullptr;
|
||||
ggml_threadpool_t threadpool_batch = nullptr;
|
||||
|
||||
|
||||
@@ -9,33 +9,6 @@
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||
const int64_t max_distance = 128;
|
||||
|
||||
if (bidirectional) {
|
||||
n_buckets >>= 1;
|
||||
}
|
||||
|
||||
const int64_t max_exact = n_buckets >> 1;
|
||||
|
||||
int32_t relative_position = x - y;
|
||||
int32_t relative_bucket = 0;
|
||||
|
||||
if (bidirectional) {
|
||||
relative_bucket += (relative_position > 0) * n_buckets;
|
||||
relative_position = abs(relative_position);
|
||||
} else {
|
||||
relative_position = -std::min<int32_t>(relative_position, 0);
|
||||
}
|
||||
|
||||
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
||||
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
||||
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
||||
|
||||
return relative_bucket;
|
||||
}
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
||||
if (pos_bucket) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
||||
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
||||
|
||||
int32_t * data = (int32_t *) pos_bucket->data;
|
||||
|
||||
const int64_t n_kv = kv_self->n;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
kv_self->set_input_pos_bucket(pos_bucket, ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask || self_kq_mask_swa) {
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
if (self_kq_mask) {
|
||||
kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_kq_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
data = (float *) self_kq_mask->data;
|
||||
}
|
||||
|
||||
if (self_kq_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
data_swa = (float *) self_kq_mask_swa->data;
|
||||
}
|
||||
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
||||
// Causal mask:
|
||||
// xxx-------
|
||||
// xxxx------
|
||||
// xxxxx-----
|
||||
// Non-causal mask:
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
// mask the token if:
|
||||
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|
||||
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
|
||||
) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(kv_self->cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
|
||||
if (data_swa) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mask padded tokens
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mask padded tokens
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (self_kq_mask_swa) {
|
||||
kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
n_layer (hparams.n_layer),
|
||||
n_rot (hparams.n_rot),
|
||||
n_ctx (cparams.n_ctx),
|
||||
n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
|
||||
n_head (hparams.n_head()),
|
||||
n_head_kv (hparams.n_head_kv()),
|
||||
n_embd_head_k (hparams.n_embd_head_k),
|
||||
@@ -971,6 +847,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
@@ -1152,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
const auto n_kv = kv_self->get_n();
|
||||
|
||||
auto & cur = inp->pos_bucket;
|
||||
|
||||
@@ -1187,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * v_mla,
|
||||
bool v_trans,
|
||||
float kq_scale) const {
|
||||
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
//const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
//const int64_t n_head = hparams.n_head(il);
|
||||
//const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
//const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
@@ -1335,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||
//cb(q, "q", il);
|
||||
|
||||
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
||||
//cb(k, "k", il);
|
||||
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
||||
//cb(k, "v", il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1368,22 +1235,17 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
{
|
||||
GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
const auto n_kv = kv_self->get_n();
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
if (hparams.n_swa_pattern > 1) {
|
||||
GGML_ASSERT(hparams.n_swa > 0);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
||||
@@ -1408,87 +1270,110 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
|
||||
const auto & n_ctx = cparams.n_ctx;
|
||||
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
const auto n_tokens = q_cur->ne[2];
|
||||
|
||||
const bool v_trans = !cparams.flash_attn;
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
const auto kv_head = kv_self->head;
|
||||
|
||||
GGML_ASSERT(kv_self->size == n_ctx);
|
||||
|
||||
ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
|
||||
//cb(k_cache_view, "k_cache_view", il);
|
||||
|
||||
// note: storing RoPE-ed version of K in the KV cache
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
|
||||
|
||||
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
|
||||
|
||||
ggml_tensor * v_cache_view = nullptr;
|
||||
|
||||
if (!v_trans) {
|
||||
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
|
||||
} else {
|
||||
// note: the V cache is transposed when not using flash attention
|
||||
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
|
||||
(kv_head)*ggml_element_size(kv_self->v_l[il]));
|
||||
|
||||
v_cur = ggml_transpose(ctx0, v_cur);
|
||||
}
|
||||
//cb(v_cache_view, "v_cache_view", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
|
||||
ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv_self->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_self->get_v(ctx0, il);
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||
//cb(q, "q", il);
|
||||
|
||||
ggml_tensor * k =
|
||||
ggml_view_3d(ctx0, kv_self->k_l[il],
|
||||
n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
||||
0);
|
||||
//cb(k, "k", il);
|
||||
|
||||
ggml_tensor * v = !v_trans ?
|
||||
ggml_view_3d(ctx0, kv_self->v_l[il],
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
|
||||
0) :
|
||||
ggml_view_3d(ctx0, kv_self->v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv_self->v_l[il])*n_ctx,
|
||||
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
cur = ggml_add(ctx0, cur, wo_b);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
|
||||
|
||||
{
|
||||
const auto n_kv = kv_self->get_kv_base()->get_n();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
if (hparams.n_swa_pattern > 1) {
|
||||
GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
|
||||
|
||||
const auto n_kv = kv_self->get_kv_swa()->get_n();
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
|
||||
|
||||
const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4) {
|
||||
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
//cb(cur, "kqv_wo", il);
|
||||
}
|
||||
@@ -1533,17 +1418,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask_cross();
|
||||
|
||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||
//cb(q, "q", il);
|
||||
|
||||
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
|
||||
//cb(k, "k", il);
|
||||
|
||||
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
|
||||
//cb(k, "v", il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1711,3 +1590,30 @@ void llm_graph_context::build_pooling(
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||
const int64_t max_distance = 128;
|
||||
|
||||
if (bidirectional) {
|
||||
n_buckets >>= 1;
|
||||
}
|
||||
|
||||
const int64_t max_exact = n_buckets >> 1;
|
||||
|
||||
int32_t relative_position = x - y;
|
||||
int32_t relative_bucket = 0;
|
||||
|
||||
if (bidirectional) {
|
||||
relative_bucket += (relative_position > 0) * n_buckets;
|
||||
relative_position = abs(relative_position);
|
||||
} else {
|
||||
relative_position = -std::min<int32_t>(relative_position, 0);
|
||||
}
|
||||
|
||||
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
|
||||
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
|
||||
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
|
||||
|
||||
return relative_bucket;
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ struct llama_cparams;
|
||||
|
||||
class llama_memory_i;
|
||||
class llama_kv_cache_unified;
|
||||
class llama_kv_cache_unified_iswa;
|
||||
class llama_kv_cache_recurrent;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
@@ -255,6 +256,31 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_unified_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_iswa * kv_self) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
kv_self(kv_self) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
@@ -266,7 +292,7 @@ public:
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
const llama_kv_cache_unified_iswa * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
@@ -298,6 +324,7 @@ class llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_tokens() = 0;
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
@@ -312,6 +339,7 @@ class llm_graph_result : public llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() override { return t_tokens; }
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
@@ -328,6 +356,7 @@ public:
|
||||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
@@ -375,7 +404,6 @@ struct llm_graph_context {
|
||||
const int64_t n_layer;
|
||||
const int64_t n_rot;
|
||||
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
||||
const int64_t n_ctx_per_seq;
|
||||
const int64_t n_head;
|
||||
const int64_t n_head_kv;
|
||||
const int64_t n_embd_head_k;
|
||||
@@ -504,13 +532,12 @@ struct llm_graph_context {
|
||||
|
||||
ggml_tensor * build_attn_mha(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
|
||||
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
bool v_trans,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale) const;
|
||||
|
||||
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||
@@ -543,6 +570,21 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
@@ -593,3 +635,6 @@ struct llm_graph_context {
|
||||
ggml_tensor * cls_out,
|
||||
ggml_tensor * cls_out_b) const;
|
||||
};
|
||||
|
||||
// TODO: better name
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
|
||||
|
||||
@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||
};
|
||||
|
||||
enum llama_swa_type {
|
||||
LLAMA_SWA_TYPE_NONE = 0,
|
||||
LLAMA_SWA_TYPE_STANDARD = 1,
|
||||
LLAMA_SWA_TYPE_CHUNKED = 2,
|
||||
};
|
||||
|
||||
struct llama_hparams_posnet {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
@@ -35,8 +41,6 @@ struct llama_hparams {
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
uint32_t n_rot;
|
||||
uint32_t n_swa = 0; // sliding window attention (SWA)
|
||||
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_expert = 0;
|
||||
@@ -96,6 +100,12 @@ struct llama_hparams {
|
||||
|
||||
std::array<int, 4> rope_sections;
|
||||
|
||||
// Sliding Window Attention (SWA)
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
uint32_t ssm_d_inner = 0;
|
||||
@@ -116,11 +126,10 @@ struct llama_hparams {
|
||||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
bool use_kq_norm = true;
|
||||
|
||||
// llama4
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
bool use_kq_norm = true;
|
||||
uint32_t n_attn_chunk = 0;
|
||||
// values below seems to be fixed on llama4
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
uint32_t n_attn_temp_floor_scale = 8192;
|
||||
float f_attn_temp_scale = 0.1;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,7 @@
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
@@ -40,6 +41,9 @@ struct llama_kv_cache : public llama_memory_i {
|
||||
// batch processing
|
||||
//
|
||||
|
||||
// =============================================================================================================
|
||||
// TODO: refactor and simplify this
|
||||
|
||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||
|
||||
// different KV caches require different batch splitting strategies
|
||||
@@ -48,6 +52,8 @@ struct llama_kv_cache : public llama_memory_i {
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
virtual bool find_slot(const llama_ubatch & batch) = 0;
|
||||
|
||||
// =============================================================================================================
|
||||
|
||||
// getters
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
@@ -87,38 +93,24 @@ private:
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
|
||||
// TODO: add notion of max sequences
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
public:
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t padding);
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t padding,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
|
||||
~llama_kv_cache_unified() = default;
|
||||
|
||||
@@ -130,10 +122,11 @@ public:
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
@@ -150,7 +143,6 @@ public:
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
// updates the cache head
|
||||
@@ -169,32 +161,72 @@ public:
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
//
|
||||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
uint32_t get_n() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
std::vector<kv_cell> cells;
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
// store k_cur and v_cur in the cache based on the current head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
|
||||
void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
|
||||
// TODO: replace with bitset uint64_t
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
struct kv_layer {
|
||||
// layer index in the model
|
||||
// note: can be different from the layer index in the KV cache
|
||||
uint32_t il;
|
||||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
};
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
// required padding
|
||||
uint32_t padding = 1;
|
||||
@@ -202,9 +234,29 @@ private:
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
// SWA
|
||||
uint32_t n_swa = 0;
|
||||
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// recovery information used to restore the KV cells to their original state in case of a failure
|
||||
struct {
|
||||
void clear() {
|
||||
cells.clear();
|
||||
}
|
||||
|
||||
std::unordered_map<uint32_t, kv_cell> cells;
|
||||
} recovery;
|
||||
|
||||
// defrag
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
@@ -213,17 +265,6 @@ private:
|
||||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// commit/restore cache
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
@@ -232,6 +273,8 @@ private:
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
@@ -258,6 +301,106 @@ private:
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa
|
||||
//
|
||||
|
||||
// utilizes two instances of llama_kv_cache_unified
|
||||
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
||||
// upon successful commit, the SWA cache removes old tokens outside the n_swa window
|
||||
|
||||
class llama_kv_cache_unified_iswa : public llama_kv_cache {
|
||||
public:
|
||||
llama_kv_cache_unified_iswa(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
bool swa_full,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_batch,
|
||||
uint32_t padding);
|
||||
|
||||
~llama_kv_cache_unified_iswa() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
void clear() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
||||
void restore() override;
|
||||
void commit() override;
|
||||
|
||||
bool update(llama_context & ctx) override;
|
||||
|
||||
void defrag_sched(float thold) override;
|
||||
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos get_pos_max() const override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_kv_base() const;
|
||||
llama_kv_cache_unified * get_kv_swa () const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
bool do_prune = true;
|
||||
|
||||
struct {
|
||||
struct entry {
|
||||
llama_pos pmin;
|
||||
llama_pos pmax;
|
||||
};
|
||||
|
||||
void clear() {
|
||||
pos.clear();
|
||||
}
|
||||
|
||||
// used to perform SWA pruning of old tokens
|
||||
std::unordered_map<llama_seq_id, entry> pos;
|
||||
} pending;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
//
|
||||
@@ -305,6 +448,7 @@ public:
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
//
|
||||
@@ -321,7 +465,6 @@ public:
|
||||
void set_full() override;
|
||||
|
||||
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
|
||||
|
||||
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
|
||||
|
||||
bool find_slot(const llama_ubatch & batch) override;
|
||||
@@ -343,11 +486,8 @@ public:
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
@@ -394,12 +534,3 @@ private:
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
|
||||
|
||||
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
|
||||
|
||||
@@ -7,8 +7,8 @@ struct llama_memory_params {
|
||||
ggml_type type_k;
|
||||
ggml_type type_v;
|
||||
|
||||
// parameters for other types of memory
|
||||
// ...
|
||||
// use full-size SWA cache
|
||||
bool swa_full;
|
||||
};
|
||||
|
||||
// general concept of LLM memory
|
||||
@@ -25,6 +25,7 @@ public:
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
|
||||
|
||||
virtual bool get_can_edit() const = 0;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user