mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-04 08:04:07 +00:00
Compare commits
106 Commits
b7767
...
gg/ngram-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c8a04576e | ||
|
|
003c90352d | ||
|
|
9f8401a533 | ||
|
|
bc33838037 | ||
|
|
351e798b2a | ||
|
|
a83c73a18a | ||
|
|
fc3cdf32ce | ||
|
|
7afdfc9b84 | ||
|
|
94eeb5967c | ||
|
|
b0311c16d2 | ||
|
|
dd23149dea | ||
|
|
72f416e973 | ||
|
|
8f80d1b254 | ||
|
|
142cbe2ac6 | ||
|
|
1f8d36665d | ||
|
|
a3300937e5 | ||
|
|
f895bca71a | ||
|
|
56f3ebf38e | ||
|
|
fd4d803c60 | ||
|
|
288ab50597 | ||
|
|
8ea068e5f8 | ||
|
|
0c21677e43 | ||
|
|
9ac881767c | ||
|
|
0440bfd160 | ||
|
|
0bf5636938 | ||
|
|
924517dd38 | ||
|
|
af382c384a | ||
|
|
bcb43163ae | ||
|
|
d9c6ce46f7 | ||
|
|
70d860824a | ||
|
|
080b161995 | ||
|
|
1243f93a2d | ||
|
|
24bc238303 | ||
|
|
16639ba217 | ||
|
|
9981c30130 | ||
|
|
cb3a40277a | ||
|
|
e9fd8dcab4 | ||
|
|
4e5b83b226 | ||
|
|
bb02f74c61 | ||
|
|
a1584ac80f | ||
|
|
1e29af4ea5 | ||
|
|
eb43748b05 | ||
|
|
b38eb5907c | ||
|
|
456268fa7f | ||
|
|
907d094f9e | ||
|
|
f1f6584ce6 | ||
|
|
917f4bb14b | ||
|
|
38f7c28795 | ||
|
|
e3e809cc01 | ||
|
|
1faeb628db | ||
|
|
1fb2658b0d | ||
|
|
8f91ca54ec | ||
|
|
81ab64f3c8 | ||
|
|
8af1f5f430 | ||
|
|
557515be1e | ||
|
|
cb6caca191 | ||
|
|
b5b8fa1c8b | ||
|
|
a14b960bc7 | ||
|
|
091a46cb8d | ||
|
|
a3e812811d | ||
|
|
51fa458a92 | ||
|
|
a5eaa1d6a3 | ||
|
|
e2baf02162 | ||
|
|
e34d6d03b2 | ||
|
|
9c96465f99 | ||
|
|
4e595b250a | ||
|
|
0e4ebeb057 | ||
|
|
8b30840703 | ||
|
|
9eb5bfec1a | ||
|
|
c6926d1d95 | ||
|
|
b70d251076 | ||
|
|
5516b9c16a | ||
|
|
94242a62c0 | ||
|
|
6b99a223e3 | ||
|
|
77078e80e5 | ||
|
|
c301172f66 | ||
|
|
3802d3c78f | ||
|
|
9da3dcd753 | ||
|
|
bd544c94a3 | ||
|
|
14be5a39b1 | ||
|
|
fbbf3ad190 | ||
|
|
33f890e579 | ||
|
|
067b8d7af3 | ||
|
|
50b7f076a5 | ||
|
|
ad8d85bd94 | ||
|
|
12a4a47e6a | ||
|
|
37c35f0e1c | ||
|
|
5bd341c9a1 | ||
|
|
1c7cf94b22 | ||
|
|
2c1f199653 | ||
|
|
d1e3556481 | ||
|
|
08f3f4a8a3 | ||
|
|
271191906c | ||
|
|
7dee9ff59a | ||
|
|
6df686bee6 | ||
|
|
1706a6d7c6 | ||
|
|
959ecf7f23 | ||
|
|
4037093c66 | ||
|
|
18361c579c | ||
|
|
365a3e8c31 | ||
|
|
3d55846a5c | ||
|
|
287a33017b | ||
|
|
293a1565dc | ||
|
|
fe44d35574 | ||
|
|
bbcdac0189 | ||
|
|
d03c45c9c5 |
12
.github/workflows/build-cache.yml
vendored
12
.github/workflows/build-cache.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
@@ -47,10 +47,10 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
@@ -73,10 +73,10 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
|
||||
2
.github/workflows/build-cmake-pkg.yml
vendored
2
.github/workflows/build-cmake-pkg.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
||||
linux:
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
14
.github/workflows/build-linux-cross.yml
vendored
14
.github/workflows/build-linux-cross.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Arm64
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture arm64
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
@@ -201,7 +201,7 @@ jobs:
|
||||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
@@ -262,10 +262,10 @@ jobs:
|
||||
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Use SpacemiT Toolchain Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
|
||||
114
.github/workflows/build.yml
vendored
114
.github/workflows/build.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -269,7 +269,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -317,7 +317,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# - name: ccache
|
||||
# uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -380,7 +380,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -414,7 +414,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -436,7 +436,7 @@ jobs:
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
@@ -472,7 +472,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -494,7 +494,7 @@ jobs:
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
@@ -543,7 +543,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -585,7 +585,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
@@ -616,7 +616,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
@@ -644,7 +644,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: add oneAPI to apt
|
||||
shell: bash
|
||||
@@ -668,7 +668,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -693,7 +693,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: add oneAPI to apt
|
||||
shell: bash
|
||||
@@ -717,7 +717,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -749,7 +749,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -781,7 +781,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -813,7 +813,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -843,7 +843,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -853,7 +853,7 @@ jobs:
|
||||
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Download xcframework artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
@@ -885,7 +885,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -954,7 +954,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1053,7 +1053,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install dependencies
|
||||
env:
|
||||
@@ -1092,7 +1092,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1145,7 +1145,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1177,7 +1177,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
@@ -1187,7 +1187,7 @@ jobs:
|
||||
7z x data.tar
|
||||
|
||||
- name: Use ROCm Installation Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
@@ -1239,7 +1239,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Xcode
|
||||
uses: maxim-lobanov/setup-xcode@v1
|
||||
@@ -1269,7 +1269,7 @@ jobs:
|
||||
./build-xcframework.sh
|
||||
|
||||
- name: Upload xcframework artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
@@ -1285,7 +1285,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Disabled due to size (400MB) and always 0 cache hits
|
||||
# - name: ccache
|
||||
@@ -1295,7 +1295,7 @@ jobs:
|
||||
# evict-old-files: 1d
|
||||
|
||||
- name: Set up JDK
|
||||
uses: actions/setup-java@v3
|
||||
uses: actions/setup-java@v5
|
||||
with:
|
||||
java-version: 17
|
||||
distribution: zulu
|
||||
@@ -1327,7 +1327,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install OpenCL Headers and Libs
|
||||
id: install_opencl
|
||||
@@ -1402,7 +1402,7 @@ jobs:
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -1460,7 +1460,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1486,7 +1486,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1512,7 +1512,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1538,7 +1538,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1564,7 +1564,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1590,7 +1590,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1604,7 +1604,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1618,7 +1618,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1632,7 +1632,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1645,7 +1645,7 @@ jobs:
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
@@ -1659,7 +1659,7 @@ jobs:
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
@@ -1673,7 +1673,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1686,7 +1686,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
@@ -1714,7 +1714,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1728,7 +1728,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1773,7 +1773,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Check environment
|
||||
run: |
|
||||
@@ -1875,7 +1875,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
@@ -1969,7 +1969,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
@@ -2043,7 +2043,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
@@ -2089,7 +2089,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
|
||||
6
.github/workflows/check-vendor.yml
vendored
6
.github/workflows/check-vendor.yml
vendored
@@ -19,16 +19,16 @@ on:
|
||||
|
||||
jobs:
|
||||
check-vendor:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
|
||||
4
.github/workflows/close-issue.yml
vendored
4
.github/workflows/close-issue.yml
vendored
@@ -10,12 +10,12 @@ permissions:
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
days-before-issue-stale: 30
|
||||
|
||||
4
.github/workflows/copilot-setup-steps.yml
vendored
4
.github/workflows/copilot-setup-steps.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
sudo chmod +x /usr/local/bin/git-clang-format
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
6
.github/workflows/docker.yml
vendored
6
.github/workflows/docker.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0 # preserve git history, so we can determine the build number
|
||||
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
@@ -208,7 +208,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
4
.github/workflows/editorconfig.yml
vendored
4
.github/workflows/editorconfig.yml
vendored
@@ -20,9 +20,9 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
editorconfig:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- uses: editorconfig-checker/action-editorconfig-checker@v2
|
||||
with:
|
||||
version: v3.0.3
|
||||
|
||||
6
.github/workflows/gguf-publish.yml
vendored
6
.github/workflows/gguf-publish.yml
vendored
@@ -21,12 +21,12 @@ on:
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.9.x'
|
||||
- name: Install dependencies
|
||||
|
||||
6
.github/workflows/labeler.yml
vendored
6
.github/workflows/labeler.yml
vendored
@@ -7,11 +7,11 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: "ggml-org/llama.cpp"
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
configuration-path: '.github/labeler.yml'
|
||||
|
||||
6
.github/workflows/pre-tokenizer-hashes.yml
vendored
6
.github/workflows/pre-tokenizer-hashes.yml
vendored
@@ -12,14 +12,14 @@ on:
|
||||
|
||||
jobs:
|
||||
pre-tokenizer-hashes:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
@@ -20,13 +20,13 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
python-check-requirements:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
name: check-requirements
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Run check-requirements.sh script
|
||||
|
||||
6
.github/workflows/python-lint.yml
vendored
6
.github/workflows/python-lint.yml
vendored
@@ -15,13 +15,13 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
flake8-lint:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
name: Lint
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: flake8 Lint
|
||||
|
||||
8
.github/workflows/python-type-check.yml
vendored
8
.github/workflows/python-type-check.yml
vendored
@@ -24,14 +24,12 @@ jobs:
|
||||
name: pyright type-check
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Python dependencies
|
||||
# TODO: use a venv
|
||||
run: pip install -r requirements/requirements-all.txt
|
||||
pip-install: -r requirements/requirements-all.txt
|
||||
- name: Type-check with Pyright
|
||||
uses: jakebailey/pyright-action@v2
|
||||
with:
|
||||
|
||||
56
.github/workflows/release.yml
vendored
56
.github/workflows/release.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
|
||||
name: llama-bin-macos-arm64.tar.gz
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -111,7 +111,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
|
||||
name: llama-bin-macos-x64.tar.gz
|
||||
@@ -133,7 +133,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -173,7 +173,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -226,7 +226,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
|
||||
name: llama-bin-ubuntu-vulkan-x64.tar.gz
|
||||
@@ -242,7 +242,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -278,7 +278,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-cpu-${{ matrix.arch }}.zip
|
||||
name: llama-bin-win-cpu-${{ matrix.arch }}.zip
|
||||
@@ -305,7 +305,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -360,7 +360,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
@@ -375,7 +375,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -416,7 +416,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
@@ -431,7 +431,7 @@ jobs:
|
||||
7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*
|
||||
|
||||
- name: Upload Cuda runtime
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
@@ -451,7 +451,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -511,7 +511,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/*
|
||||
|
||||
- name: Upload the release package
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
@@ -531,7 +531,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
@@ -542,7 +542,7 @@ jobs:
|
||||
|
||||
- name: Cache ROCm Installation
|
||||
id: cache-rocm
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
|
||||
@@ -617,7 +617,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
@@ -627,7 +627,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -672,7 +672,7 @@ jobs:
|
||||
zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
@@ -703,7 +703,7 @@ jobs:
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -763,7 +763,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
|
||||
@@ -794,7 +794,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -804,7 +804,7 @@ jobs:
|
||||
|
||||
- name: Download artifacts
|
||||
id: download-artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
path: ./artifact
|
||||
merge-multiple: true
|
||||
@@ -887,7 +887,7 @@ jobs:
|
||||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
uses: actions/github-script@v3
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{secrets.GITHUB_TOKEN}}
|
||||
script: |
|
||||
@@ -897,7 +897,7 @@ jobs:
|
||||
for (let file of await fs.readdirSync('./release')) {
|
||||
if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) {
|
||||
console.log('uploadReleaseAsset', file);
|
||||
await github.repos.uploadReleaseAsset({
|
||||
await github.rest.repos.uploadReleaseAsset({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
release_id: release_id,
|
||||
|
||||
10
.github/workflows/server-webui.yml
vendored
10
.github/workflows/server-webui.yml
vendored
@@ -37,14 +37,14 @@ jobs:
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
id: node
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
@@ -131,14 +131,14 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Setup Node.js for WebUI
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
|
||||
12
.github/workflows/server.yml
vendored
12
.github/workflows/server.yml
vendored
@@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
@@ -72,12 +72,12 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
@@ -108,12 +108,12 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
6
.github/workflows/update-ops-docs.yml
vendored
6
.github/workflows/update-ops-docs.yml
vendored
@@ -14,14 +14,14 @@ on:
|
||||
|
||||
jobs:
|
||||
update-ops-docs:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
|
||||
4
.github/workflows/winget.yml
vendored
4
.github/workflows/winget.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
jobs:
|
||||
update:
|
||||
name: Update Winget Package
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
if: github.repository_owner == 'ggml-org'
|
||||
|
||||
steps:
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
uses: actions/github-script@v6
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const { data: releases } = await github.rest.repos.listReleases({
|
||||
|
||||
@@ -15,8 +15,10 @@
|
||||
/common/common.* @ggerganov
|
||||
/common/console.* @ggerganov
|
||||
/common/http.* @angt
|
||||
/common/jinja/ @ngxson @CISC @aldehir
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/ngram-map.* @srogmann
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
|
||||
@@ -132,6 +132,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
|
||||
- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
|
||||
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
|
||||
- [x] [RWKV-7](https://huggingface.co/collections/shoumenchougou/rwkv7-gxx-gguf)
|
||||
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
|
||||
- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
|
||||
@@ -254,7 +254,7 @@ function gg_run_ctest_release {
|
||||
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
|
||||
|
||||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
(time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||
(time ctest --output-on-failure -L 'main|python' ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||
else
|
||||
(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||
fi
|
||||
|
||||
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
ngram-map.cpp
|
||||
ngram-map.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
|
||||
117
common/arg.cpp
117
common/arg.cpp
@@ -6,6 +6,7 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
@@ -1216,21 +1217,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"-lcs", "--lookup-cache-static"}, "FNAME",
|
||||
"path to static lookup cache to use for lookup decoding (not updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_static = value;
|
||||
params.speculative.lookup_cache_static = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
|
||||
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_dynamic = value;
|
||||
params.speculative.lookup_cache_dynamic = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-c", "--ctx-size"}, "N",
|
||||
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
|
||||
[](common_params & params, int value) {
|
||||
params.n_ctx = value;
|
||||
if (value == 0) {
|
||||
// disable context reduction in llama_params_fit if the user explicitly requests the full context size:
|
||||
params.fit_params_min_ctx = UINT32_MAX;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_CTX_SIZE"));
|
||||
add_opt(common_arg(
|
||||
@@ -1573,7 +1578,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--temp"}, "N",
|
||||
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
|
||||
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.temp = std::stof(value);
|
||||
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
|
||||
@@ -1590,7 +1595,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
|
||||
add_opt(common_arg(
|
||||
{"--top-p"}, "N",
|
||||
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
string_format("top-p sampling (default: %.2f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
|
||||
@@ -1598,7 +1603,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--min-p"}, "N",
|
||||
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
string_format("min-p sampling (default: %.2f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.min_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
|
||||
@@ -1606,14 +1611,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-nsigma"}, "N",
|
||||
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_n_sigma = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
string_format("xtc probability (default: %.2f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_probability = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
|
||||
@@ -1621,7 +1626,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-threshold"}, "N",
|
||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
string_format("xtc threshold (default: %.2f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_threshold = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
|
||||
@@ -1629,7 +1634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--typical"}, "N",
|
||||
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.typ_p = std::stof(value);
|
||||
}
|
||||
@@ -1648,7 +1653,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--repeat-penalty"}, "N",
|
||||
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
string_format("penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_repeat = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
|
||||
@@ -1656,21 +1661,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--presence-penalty"}, "N",
|
||||
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
|
||||
string_format("repeat alpha presence penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_present),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_present = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--frequency-penalty"}, "N",
|
||||
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
|
||||
string_format("repeat alpha frequency penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_freq = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dry-multiplier"}, "N",
|
||||
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
|
||||
string_format("set DRY sampling multiplier (default: %.2f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dry_multiplier = std::stof(value);
|
||||
}
|
||||
@@ -1751,14 +1756,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dynatemp-range"}, "N",
|
||||
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
|
||||
string_format("dynamic temperature range (default: %.2f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dynatemp_range = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dynatemp-exp"}, "N",
|
||||
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
|
||||
string_format("dynamic temperature exponent (default: %.2f)", (double)params.sampling.dynatemp_exponent),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dynatemp_exponent = std::stof(value);
|
||||
}
|
||||
@@ -1774,7 +1779,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat-lr"}, "N",
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.2f)", (double)params.sampling.mirostat_eta),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_eta = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
|
||||
@@ -1782,7 +1787,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat-ent"}, "N",
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.2f)", (double)params.sampling.mirostat_tau),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_tau = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
|
||||
@@ -1916,28 +1921,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-ext-factor"}, "N",
|
||||
string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||
string_format("YaRN: extrapolation mix factor (default: %.2f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_ext_factor = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-attn-factor"}, "N",
|
||||
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
|
||||
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.2f)", (double)params.yarn_attn_factor),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_attn_factor = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-beta-slow"}, "N",
|
||||
string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
|
||||
string_format("YaRN: high correction dim or alpha (default: %.2f)", (double)params.yarn_beta_slow),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_beta_slow = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-beta-fast"}, "N",
|
||||
string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
|
||||
string_format("YaRN: low correction dim or beta (default: %.2f)", (double)params.yarn_beta_fast),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_beta_fast = std::stof(value);
|
||||
}
|
||||
@@ -3331,14 +3336,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-p-split"}, "P",
|
||||
string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
|
||||
string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.p_split),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.p_split = std::stof(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-p-min"}, "P",
|
||||
string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
|
||||
string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.p_min = std::stof(value);
|
||||
}
|
||||
@@ -3392,6 +3397,68 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draftless"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-map-mod]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
} else if (value == "ngram-cache") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
|
||||
} else if (value == "ngram-simple") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
|
||||
} else if (value == "ngram-map-k") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
} else if (value == "ngram-map-k4v") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else if (value == "ngram-map-mod") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_n = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_m = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-check-rate"}, "N",
|
||||
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram check rate must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_check_rate = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram min hits must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_min_hits = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
|
||||
string_format(
|
||||
|
||||
@@ -129,7 +129,7 @@ static void parse_json_tool_calls(
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
result_.role = "assistant";
|
||||
@@ -1611,7 +1611,7 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
builder.finish();
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
@@ -1630,12 +1630,12 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
||||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
|
||||
if (parser.empty()) {
|
||||
throw std::runtime_error("Failed to parse due to missing parser definition.");
|
||||
}
|
||||
@@ -1663,7 +1663,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
}
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
@@ -19,20 +19,20 @@ class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
bool is_partial_;
|
||||
common_chat_syntax syntax_;
|
||||
common_chat_parser_params syntax_; // TODO: rename to params
|
||||
std::string healing_marker_;
|
||||
|
||||
size_t pos_ = 0;
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
const bool & is_partial() const { return is_partial_; }
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
const common_chat_syntax & syntax() const { return syntax_; }
|
||||
const common_chat_parser_params & syntax() const { return syntax_; }
|
||||
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
|
||||
240
common/chat.cpp
240
common/chat.cpp
@@ -7,9 +7,6 @@
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
// #include <minja/chat-template.hpp>
|
||||
// #include <minja/minja.hpp>
|
||||
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/value.h"
|
||||
#include "jinja/runtime.h"
|
||||
@@ -56,39 +53,73 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
|
||||
return !msg.content.empty() || !msg.tool_calls.empty();
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msg::to_json_oaicompat() const
|
||||
{
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!reasoning_content.empty()) {
|
||||
message["reasoning_content"] = reasoning_content;
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
if (content.empty() && !tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
json jmsg {
|
||||
{"role", role},
|
||||
};
|
||||
if (!content.empty()) {
|
||||
jmsg["content"] = content;
|
||||
} else if (!content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
message["content"] = content;
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
if (!tool_name.empty()) {
|
||||
jmsg["name"] = tool_name;
|
||||
}
|
||||
if (!tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = tool_call_id;
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & tc : tool_calls) {
|
||||
arr.push_back({
|
||||
jmsg["tool_calls"] = json::array();
|
||||
auto & jtool_calls = jmsg["tool_calls"];
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
{"id", tc.id},
|
||||
// // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// // We only generate a random id for the ones that don't generate one by themselves
|
||||
// // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
jtool_calls.push_back(tc);
|
||||
}
|
||||
message["tool_calls"] = arr;
|
||||
}
|
||||
return message;
|
||||
|
||||
return jmsg;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
|
||||
@@ -256,7 +287,6 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
|
||||
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
||||
std::vector<common_chat_msg> msgs;
|
||||
|
||||
@@ -350,80 +380,15 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
return msgs;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
json jmsg {
|
||||
{"role", msg.role},
|
||||
};
|
||||
if (!msg.content.empty()) {
|
||||
jmsg["content"] = msg.content;
|
||||
} else if (!msg.content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : msg.content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
}
|
||||
if (!msg.tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = msg.tool_call_id;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
json jmsg = msg.to_json_oaicompat(concat_typed_text);
|
||||
messages.push_back(jmsg);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
||||
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
@@ -459,12 +424,6 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
||||
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
||||
if (tools.empty()) {
|
||||
return json();
|
||||
@@ -484,7 +443,7 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
|
||||
return result;
|
||||
}
|
||||
|
||||
template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json delta = json::object();
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
delta["reasoning_content"] = diff.reasoning_content_delta;
|
||||
@@ -601,18 +560,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
|
||||
return tmpls->has_explicit_template;
|
||||
}
|
||||
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
||||
if (variant != nullptr) {
|
||||
if (strcmp(variant, "tool_use") == 0) {
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
|
||||
if (!variant.empty()) {
|
||||
if (variant == "tool_use") {
|
||||
if (tmpls->template_tool_use) {
|
||||
return tmpls->template_tool_use->source().c_str();
|
||||
return tmpls->template_tool_use->source();
|
||||
}
|
||||
return nullptr;
|
||||
return "";
|
||||
} else {
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
|
||||
}
|
||||
}
|
||||
return tmpls->template_default->source().c_str();
|
||||
return tmpls->template_default->source();
|
||||
}
|
||||
|
||||
common_chat_templates_ptr common_chat_templates_init(
|
||||
@@ -2691,6 +2650,51 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// This template does not support tools or reasoning
|
||||
// we just need to transform the messages into the correct schema
|
||||
|
||||
templates_params inputs_new = inputs;
|
||||
json & messages = inputs_new.messages;
|
||||
|
||||
// default to chat_template_kwargs, or en-GB if not specified
|
||||
std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB");
|
||||
std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB");
|
||||
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role") && message["role"].get<std::string>() != "user") {
|
||||
continue;
|
||||
}
|
||||
if (!message.contains("content")) {
|
||||
message["content"] = json::array();
|
||||
}
|
||||
if (message.contains("content") && !message["content"].is_array()) {
|
||||
auto content_str = message["content"].get<std::string>();
|
||||
// default to en-GB if not specified (to make common_chat_format_example works)
|
||||
auto src_lang = message.contains("source_lang_code")
|
||||
? message["source_lang_code"].get<std::string>() : default_src_lang;
|
||||
auto tgt_lang = message.contains("target_lang_code")
|
||||
? message["target_lang_code"].get<std::string>() : default_tgt_lang;
|
||||
message["content"] = json::array({
|
||||
json{
|
||||
{"type", "text"},
|
||||
{"text", content_str},
|
||||
{"source_lang_code", src_lang},
|
||||
{"target_lang_code", tgt_lang},
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -2867,13 +2871,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
@@ -2943,6 +2947,10 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
if (!params.extra_context.contains("clear_thinking")) {
|
||||
// by default, do not clear reasoning_content (added since GLM-4.7)
|
||||
params.extra_context["clear_thinking"] = false;
|
||||
}
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
@@ -3082,6 +3090,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// TranslateGemma
|
||||
if (src.find("[source_lang_code]") != std::string::npos &&
|
||||
src.find("[target_lang_code]") != std::string::npos) {
|
||||
return common_chat_params_init_translate_gemma(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
@@ -3174,3 +3188,9 @@ common_chat_params common_chat_templates_apply(
|
||||
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||
}
|
||||
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
struct common_chat_tool_call {
|
||||
@@ -26,6 +28,11 @@ struct common_chat_msg_content_part {
|
||||
std::string type;
|
||||
std::string text;
|
||||
|
||||
// TODO @ngxson : no known chat templates support reasoning_content in content parts yet
|
||||
// this can be useful for models with interleaved thinking (like Kimi-K2)
|
||||
// if you see any templates explicitly support this, please ping me
|
||||
// std::string reasoning_content;
|
||||
|
||||
bool operator==(const common_chat_msg_content_part & other) const {
|
||||
return type == other.type && text == other.text;
|
||||
}
|
||||
@@ -40,7 +47,7 @@ struct common_chat_msg {
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
@@ -145,7 +152,7 @@ struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
@@ -165,14 +172,21 @@ struct common_chat_params {
|
||||
std::string parser;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
// per-message parsing syntax
|
||||
// should be derived from common_chat_params
|
||||
struct common_chat_parser_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
thinking_forced_open = chat_params.thinking_forced_open;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
@@ -191,7 +205,7 @@ common_chat_templates_ptr common_chat_templates_init(
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(
|
||||
@@ -213,23 +227,25 @@ std::string common_chat_format_example(
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
|
||||
|
||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
// get template caps, useful for reporting to server /props endpoint
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
|
||||
|
||||
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
|
||||
@@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT;
|
||||
extern const char * LLAMA_COMPILER;
|
||||
extern const char * LLAMA_BUILD_TARGET;
|
||||
|
||||
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
|
||||
|
||||
struct common_control_vector_load_info;
|
||||
|
||||
//
|
||||
@@ -162,6 +164,17 @@ enum common_params_sampling_config : uint64_t {
|
||||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
@@ -249,6 +262,7 @@ struct common_params_speculative {
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
@@ -259,6 +273,20 @@ struct common_params_speculative {
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
struct common_params_model model;
|
||||
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
bool has_dft() const {
|
||||
return !model.path.empty() || !model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
@@ -284,6 +312,7 @@ struct common_params_diffusion {
|
||||
};
|
||||
|
||||
// reasoning API response format (not to be confused as chat template's reasoning format)
|
||||
// only used by server
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
|
||||
@@ -375,8 +404,6 @@ struct common_params {
|
||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
@@ -572,10 +599,6 @@ struct common_params {
|
||||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@@ -711,8 +734,6 @@ struct common_init_result {
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
||||
@@ -314,23 +314,26 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
|
||||
if (!bearer_token.empty()) {
|
||||
default_headers.insert({"Authorization", "Bearer " + bearer_token});
|
||||
}
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : custom_headers) {
|
||||
default_headers.emplace(h.first, h.second);
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
cli.set_default_headers(default_headers);
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
@@ -437,10 +440,12 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
const common_remote_params & params) {
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
|
||||
|
||||
for (const auto & header : params.headers) {
|
||||
headers.emplace(header.first, header.second);
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : params.headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
|
||||
if (params.timeout > 0) {
|
||||
|
||||
@@ -57,6 +57,17 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
throw std::runtime_error("error: invalid URL format");
|
||||
}
|
||||
|
||||
#ifndef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
if (parts.scheme == "https") {
|
||||
throw std::runtime_error(
|
||||
"HTTPS is not supported. Please rebuild with one of:\n"
|
||||
" -DLLAMA_BUILD_BORINGSSL=ON\n"
|
||||
" -DLLAMA_BUILD_LIBRESSL=ON\n"
|
||||
" -DLLAMA_OPENSSL=ON (default, requires OpenSSL dev files installed)"
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
||||
|
||||
if (!parts.user.empty()) {
|
||||
|
||||
@@ -61,14 +61,23 @@ static void caps_print_stats(value & v, const std::string & path) {
|
||||
ops.c_str());
|
||||
}
|
||||
|
||||
std::map<std::string, bool> caps::to_map() const {
|
||||
return {
|
||||
{"requires_typed_content", requires_typed_content},
|
||||
{"supports_tools", supports_tools},
|
||||
{"supports_tool_calls", supports_tool_calls},
|
||||
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
|
||||
{"supports_system_role", supports_system_role},
|
||||
{"supports_preserve_reasoning", supports_preserve_reasoning},
|
||||
};
|
||||
}
|
||||
|
||||
std::string caps::to_string() const {
|
||||
std::ostringstream ss;
|
||||
ss << "Caps(\n";
|
||||
ss << " requires_typed_content=" << requires_typed_content << "\n";
|
||||
ss << " supports_tools=" << supports_tools << "\n";
|
||||
ss << " supports_tool_calls=" << supports_tool_calls << "\n";
|
||||
ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
|
||||
ss << " supports_system_role=" << supports_system_role << "\n";
|
||||
for (const auto & [key, value] : to_map()) {
|
||||
ss << " " << key << "=" << (value ? "true" : "false") << "\n";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
}
|
||||
@@ -229,6 +238,40 @@ caps caps_get(jinja::program & prog) {
|
||||
}
|
||||
);
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
{"reasoning_content", "Reasoning content"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", result.to_string().c_str());
|
||||
|
||||
return result;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "runtime.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
@@ -11,14 +12,17 @@ struct caps {
|
||||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
bool requires_typed_content = false; // default: use string content
|
||||
|
||||
// for reporting on server
|
||||
std::map<std::string, bool> to_map() const;
|
||||
|
||||
// for debugging
|
||||
std::string to_string() const;
|
||||
};
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
void debug_print_caps(const caps & c);
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -268,8 +268,7 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
// String in object
|
||||
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
|
||||
auto key = left_val->as_string().str();
|
||||
auto & obj = right_val->as_object();
|
||||
bool has_key = obj.find(key) != obj.end();
|
||||
bool has_key = right_val->has_key(key);
|
||||
if (op.value == "in") {
|
||||
return mk_val<value_bool>(has_key);
|
||||
} else if (op.value == "not in") {
|
||||
@@ -464,7 +463,7 @@ value for_statement::execute_impl(context & ctx) {
|
||||
std::vector<value> items;
|
||||
if (is_val<value_object>(iterable_val)) {
|
||||
JJ_DEBUG("%s", "For loop over object keys");
|
||||
auto & obj = iterable_val->as_object();
|
||||
auto & obj = iterable_val->as_ordered_object();
|
||||
for (auto & p : obj) {
|
||||
auto tuple = mk_val<value_array>();
|
||||
if (iterable_val->val_obj.is_key_numeric) {
|
||||
@@ -779,11 +778,8 @@ value member_expression::execute_impl(context & ctx) {
|
||||
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
|
||||
}
|
||||
auto key = property->as_string().str();
|
||||
auto & obj = object->as_object();
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.end()) {
|
||||
val = it->second;
|
||||
} else {
|
||||
val = object->at(key, val);
|
||||
if (is_val<value_undefined>(val)) {
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
}
|
||||
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
|
||||
@@ -809,7 +805,7 @@ value member_expression::execute_impl(context & ctx) {
|
||||
} else if (is_val<value_string>(property)) {
|
||||
auto key = property->as_string().str();
|
||||
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
|
||||
val = try_builtin_func(ctx, key, object);
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
} else {
|
||||
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
|
||||
}
|
||||
@@ -818,7 +814,7 @@ value member_expression::execute_impl(context & ctx) {
|
||||
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
|
||||
}
|
||||
auto key = property->as_string().str();
|
||||
val = try_builtin_func(ctx, key, object);
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
}
|
||||
|
||||
if (ctx.is_get_stats && val && object && property) {
|
||||
|
||||
@@ -69,7 +69,7 @@ struct context {
|
||||
|
||||
context(const context & parent) : context() {
|
||||
// inherit variables (for example, when entering a new scope)
|
||||
auto & pvar = parent.env->as_object();
|
||||
auto & pvar = parent.env->as_ordered_object();
|
||||
for (const auto & pair : pvar) {
|
||||
set_val(pair.first, pair.second);
|
||||
}
|
||||
|
||||
@@ -776,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("join() first argument must be an array");
|
||||
}
|
||||
value val_delim = args.get_kwarg_or_pos("d", 1);
|
||||
value val_attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
if (!val_attribute->is_undefined()) {
|
||||
throw not_implemented_exception("array attribute join not implemented");
|
||||
}
|
||||
value val_delim = args.get_kwarg_or_pos("d", 1);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
const auto & arr = args.get_pos(0)->as_array();
|
||||
std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
|
||||
throw raised_exception("join() attribute must be string or integer");
|
||||
}
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
|
||||
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
|
||||
std::string result;
|
||||
for (size_t i = 0; i < arr.size(); ++i) {
|
||||
if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
|
||||
value val_arr = arr[i];
|
||||
if (!attribute->is_undefined()) {
|
||||
if (attr_is_int && is_val<value_array>(val_arr)) {
|
||||
val_arr = val_arr->at(attr_int);
|
||||
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
|
||||
val_arr = val_arr->at(attr_name);
|
||||
}
|
||||
}
|
||||
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
|
||||
throw raised_exception("join() can only join arrays of strings or numerics");
|
||||
}
|
||||
result += arr[i]->as_string().str();
|
||||
result += val_arr->as_string().str();
|
||||
if (i < arr.size() - 1) {
|
||||
result += delim;
|
||||
}
|
||||
@@ -803,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"map", [](const func_args & args) -> value {
|
||||
args.ensure_count(2, 3);
|
||||
args.ensure_count(2);
|
||||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("map: first argument must be an array");
|
||||
}
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 1);
|
||||
if (is_val<value_int>(attribute)) {
|
||||
throw not_implemented_exception("map: integer attribute not implemented");
|
||||
if (!is_val<value_kwarg>(args.get_args().at(1))) {
|
||||
throw not_implemented_exception("map: filter-mapping not implemented");
|
||||
}
|
||||
if (!is_val<value_string>(attribute)) {
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 1);
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
if (!is_val<value_string>(attribute) && !attr_is_int) {
|
||||
throw raised_exception("map: attribute must be string or integer");
|
||||
}
|
||||
std::string attr_name = attribute->as_string().str();
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string attr_name = attribute->as_string().str();
|
||||
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
|
||||
auto out = mk_val<value_array>();
|
||||
auto arr = args.get_pos(0)->as_array();
|
||||
for (const auto & item : arr) {
|
||||
if (!is_val<value_object>(item)) {
|
||||
throw raised_exception("map: item is not an object");
|
||||
value attr_val;
|
||||
if (attr_is_int) {
|
||||
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
|
||||
} else {
|
||||
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
|
||||
}
|
||||
value attr_val = item->at(attr_name, default_val);
|
||||
out->push_back(attr_val);
|
||||
}
|
||||
return out;
|
||||
@@ -848,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
return arr_editable->pop_at(index);
|
||||
}},
|
||||
{"sort", [](const func_args & args) -> value {
|
||||
args.ensure_count(1, 3);
|
||||
args.ensure_count(1, 4);
|
||||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("sort: first argument must be an array");
|
||||
}
|
||||
bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
|
||||
value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
|
||||
std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
|
||||
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 3);
|
||||
// FIXME: sorting is currently always case sensitive
|
||||
//const bool case_sensitive = val_case->as_bool(); // undefined == false
|
||||
const bool reverse = val_reverse->as_bool(); // undefined == false
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
|
||||
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
|
||||
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
|
||||
value val_a = a;
|
||||
value val_b = b;
|
||||
if (!attribute->is_undefined()) {
|
||||
if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
|
||||
throw raised_exception("sort: items are not objects");
|
||||
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
|
||||
val_a = a->at(attr_int);
|
||||
val_b = b->at(attr_int);
|
||||
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
|
||||
val_a = a->at(attr_name);
|
||||
val_b = b->at(attr_name);
|
||||
} else {
|
||||
throw raised_exception("sort: unsupported object attribute comparison");
|
||||
}
|
||||
val_a = attr.empty() ? a : a->at(attr);
|
||||
val_b = attr.empty() ? b : b->at(attr);
|
||||
}
|
||||
if (reverse) {
|
||||
return value_compare(val_a, val_b, value_compare_op::gt);
|
||||
} else {
|
||||
return !value_compare(val_a, val_b, value_compare_op::gt);
|
||||
}
|
||||
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
});
|
||||
return mk_val<value_array>(arr);
|
||||
}},
|
||||
@@ -908,18 +929,13 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
if (args.count() == 3) {
|
||||
default_val = args.get_pos(2);
|
||||
}
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const value obj = args.get_pos(0);
|
||||
std::string key = args.get_pos(1)->as_string().str();
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
return default_val;
|
||||
}
|
||||
return obj->at(key, default_val);
|
||||
}},
|
||||
{"keys", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
result->push_back(mk_val<value_string>(pair.first));
|
||||
@@ -928,7 +944,7 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
}},
|
||||
{"values", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
result->push_back(pair.second);
|
||||
@@ -937,7 +953,7 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
}},
|
||||
{"items", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
auto item = mk_val<value_array>();
|
||||
@@ -951,7 +967,7 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
{"string", tojson},
|
||||
{"length", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
return mk_val<value_int>(static_cast<int64_t>(obj.size()));
|
||||
}},
|
||||
{"tojson", [](const func_args & args) -> value {
|
||||
@@ -964,21 +980,18 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
|
||||
value val_by = args.get_kwarg_or_pos("by", 2);
|
||||
value val_reverse = args.get_kwarg_or_pos("reverse", 3);
|
||||
// FIXME: sorting is case sensitive
|
||||
// FIXME: sorting is currently always case sensitive
|
||||
//const bool case_sensitive = val_case->as_bool(); // undefined == false
|
||||
const bool reverse = val_reverse->as_bool(); // undefined == false
|
||||
if (!val_by->is_undefined()) {
|
||||
throw not_implemented_exception("dictsort by key not implemented");
|
||||
}
|
||||
if (reverse) {
|
||||
throw not_implemented_exception("dictsort reverse not implemented");
|
||||
}
|
||||
value_t::map obj = val_input->val_obj; // copy
|
||||
std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) {
|
||||
return a.first < b.first;
|
||||
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
|
||||
auto result = mk_val<value_object>(val_input); // copy
|
||||
std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
|
||||
if (by_value) {
|
||||
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
} else {
|
||||
return reverse ? a.first > b.first : a.first < b.first;
|
||||
}
|
||||
});
|
||||
auto result = mk_val<value_object>();
|
||||
result->val_obj = std::move(obj);
|
||||
return result;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
@@ -992,6 +1005,7 @@ const func_builtins & value_none_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args &) -> value { return mk_val<value_string>("None"); }}
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
@@ -1175,7 +1189,7 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
|
||||
}
|
||||
oss << "]";
|
||||
} else if (is_val<value_object>(val)) {
|
||||
const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order
|
||||
const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order
|
||||
oss << "{";
|
||||
if (!obj.empty()) {
|
||||
oss << newline();
|
||||
|
||||
@@ -146,7 +146,7 @@ struct value_t {
|
||||
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
|
||||
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
|
||||
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
|
||||
virtual bool is_none() const { return false; }
|
||||
virtual bool is_undefined() const { return false; }
|
||||
@@ -154,6 +154,9 @@ struct value_t {
|
||||
throw std::runtime_error("No builtins available for type " + type());
|
||||
}
|
||||
|
||||
virtual bool has_key(const std::string & key) {
|
||||
return val_obj.unordered.find(key) != val_obj.unordered.end();
|
||||
}
|
||||
virtual value & at(const std::string & key, value & default_val) {
|
||||
auto it = val_obj.unordered.find(key);
|
||||
if (it == val_obj.unordered.end()) {
|
||||
@@ -168,8 +171,20 @@ struct value_t {
|
||||
}
|
||||
return val_obj.unordered.at(key);
|
||||
}
|
||||
virtual value & at(size_t index) {
|
||||
if (index >= val_arr.size()) {
|
||||
virtual value & at(int64_t index, value & default_val) {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
return default_val;
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual value & at(int64_t index) {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
|
||||
}
|
||||
return val_arr[index];
|
||||
@@ -188,6 +203,9 @@ struct value_int_t : public value_t {
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual double as_float() const override { return static_cast<double>(val_int); }
|
||||
virtual string as_string() const override { return std::to_string(val_int); }
|
||||
virtual bool as_bool() const override {
|
||||
return val_int != 0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
};
|
||||
using value_int = std::shared_ptr<value_int_t>;
|
||||
@@ -204,6 +222,9 @@ struct value_float_t : public value_t {
|
||||
if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
|
||||
return out;
|
||||
}
|
||||
virtual bool as_bool() const override {
|
||||
return val_flt != 0.0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
};
|
||||
using value_float = std::shared_ptr<value_float_t>;
|
||||
@@ -296,11 +317,16 @@ struct value_object_t : public value_t {
|
||||
val_obj.insert(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
|
||||
for (const auto & pair : obj) {
|
||||
val_obj.insert(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
void insert(const std::string & key, const value & val) {
|
||||
val_obj.insert(key, val);
|
||||
}
|
||||
virtual std::string type() const override { return "Object"; }
|
||||
virtual const std::map<std::string, value> & as_object() const override { return val_obj.unordered; }
|
||||
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
|
||||
virtual bool as_bool() const override {
|
||||
return !val_obj.unordered.empty();
|
||||
}
|
||||
@@ -316,12 +342,12 @@ struct value_none_t : public value_t {
|
||||
virtual std::string type() const override { return "None"; }
|
||||
virtual bool is_none() const override { return true; }
|
||||
virtual bool as_bool() const override { return false; }
|
||||
virtual string as_string() const override { return string("None"); }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
};
|
||||
using value_none = std::shared_ptr<value_none_t>;
|
||||
|
||||
|
||||
struct value_undefined_t : public value_t {
|
||||
std::string hint; // for debugging, to indicate where undefined came from
|
||||
value_undefined_t(const std::string & h = "") : hint(h) {}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: use json_fwd.hpp when possible
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
|
||||
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
|
||||
break;
|
||||
}
|
||||
|
||||
LOG(" - draft candidate: token=%d\n", drafted_token);
|
||||
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
|
||||
draft.push_back(drafted_token);
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
|
||||
std::ofstream file_out(filename, std::ios::binary);
|
||||
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
|
||||
const common_ngram ngram = item.first;
|
||||
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
|
||||
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename) {
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
|
||||
std::ifstream hashmap_file(filename, std::ios::binary);
|
||||
if (!hashmap_file) {
|
||||
throw std::ifstream::failure("Unable to open file " + filename);
|
||||
|
||||
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
|
||||
// Save an ngram cache to a file.
|
||||
// ngram_cache: the ngram cache to save.
|
||||
// filename: the path under which to save the ngram cache.
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
|
||||
|
||||
// Load an ngram cache saved with common_ngram_cache_save.
|
||||
// filename: the path from which to load the ngram cache.
|
||||
// returns: an ngram cache containing the information saved to filename.
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename);
|
||||
|
||||
// Merge two ngram caches.
|
||||
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
||||
|
||||
457
common/ngram-map.cpp
Normal file
457
common/ngram-map.cpp
Normal file
@@ -0,0 +1,457 @@
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "ngram-map.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
|
||||
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
|
||||
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << inp[start + i];
|
||||
}
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
/**
|
||||
* Perform speculative generation using the model's own token history.
|
||||
* Searches for a matching pattern in the token history and returns draft tokens.
|
||||
*
|
||||
* @param state Current state of this implementation
|
||||
* @param tokens Token history to search in
|
||||
* @param sampled Last sampled token
|
||||
* @return Vector of draft tokens, empty if no matching pattern is found
|
||||
*/
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled) {
|
||||
|
||||
// Simple implementation of self-speculative decoding without a draft model.
|
||||
//
|
||||
const size_t cur_len = tokens.size();
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (state.idx_last_check + state.config.check_rate > cur_len && cur_len > state.idx_last_check) {
|
||||
llama_tokens draft_tokens;
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
|
||||
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
|
||||
|
||||
// vector for tokens we want to verify.
|
||||
// return empty vector if there is no match.
|
||||
llama_tokens draft_tokens;
|
||||
|
||||
// We need at least n_draft_min + n_draft_max + 1 tokens.
|
||||
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
// pattern search
|
||||
llama_tokens pattern;
|
||||
pattern.reserve(n_draft_min);
|
||||
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
|
||||
pattern.push_back(tokens[j]);
|
||||
}
|
||||
pattern.push_back(sampled); // add the last token to the pattern
|
||||
|
||||
// We do a search in the token history.
|
||||
state.idx_last_check = cur_len;
|
||||
|
||||
size_t match_pos = 0; // we ignore position 0, position 0 == no match
|
||||
// search backwards, but skip the current match (we are currently there)
|
||||
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < pattern.size(); ++k) {
|
||||
if (tokens[j + k] != pattern[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos == 0) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
const size_t copy_max = std::min(
|
||||
n_draft_max,
|
||||
cur_len - (match_pos + n_draft_min)
|
||||
);
|
||||
if (copy_max < n_draft_min) {
|
||||
return draft_tokens;
|
||||
}
|
||||
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
|
||||
__func__, cur_len,
|
||||
match_pos, pattern.size(), copy_max);
|
||||
|
||||
draft_tokens.reserve(copy_max);
|
||||
for (size_t j = 0; j < copy_max; ++j) {
|
||||
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
|
||||
}
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of counted values of a ngram map value.
|
||||
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
|
||||
|
||||
void common_ngram_map_draft(common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft) {
|
||||
// reset last key and value.
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = 0;
|
||||
map.last_draft_value_idx = 0;
|
||||
|
||||
const size_t cur_len = inp.size();
|
||||
const uint16_t n = map.size_key;
|
||||
const uint16_t m = map.size_value;
|
||||
if (cur_len < static_cast<size_t>(2 * n + m)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (map.idx_last_check + map.check_rate > cur_len && cur_len > map.idx_last_check) {
|
||||
return;
|
||||
}
|
||||
map.idx_last_check = cur_len;
|
||||
|
||||
// search pattern, the key n-gram
|
||||
std::vector<llama_token> key_tokens;
|
||||
key_tokens.reserve(n);
|
||||
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
|
||||
key_tokens.push_back(inp[j]);
|
||||
}
|
||||
key_tokens.push_back(sampled);
|
||||
|
||||
// search for the key in the map
|
||||
size_t match_pos = 0;
|
||||
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[j + k] != key_tokens[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos > 0) {
|
||||
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
|
||||
cur_len, n, m, key_tokens.size(), sampled, match_pos);
|
||||
}
|
||||
|
||||
if (match_pos == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We have a match, now we look for the statistics of the key.
|
||||
size_t key_offset = map.keys.size(); // offset in the map
|
||||
// We iterate through the std::vector<common_ngram_map_key> map->keys.
|
||||
for (size_t i = 0; i < map.keys.size(); ++i) {
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
key_offset = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (key_offset == map.keys.size()) {
|
||||
// We create a new key-entry, it will get offset key_offset.
|
||||
common_ngram_map_key new_key;
|
||||
new_key.key_idx = match_pos;
|
||||
new_key.stat_idx = 0;
|
||||
new_key.key_num = 0;
|
||||
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
|
||||
new_key.values[i].value_num = 0;
|
||||
new_key.values[i].n_accepted = m;
|
||||
}
|
||||
map.keys.push_back(new_key);
|
||||
}
|
||||
|
||||
// our key n-gram:
|
||||
common_ngram_map_key & curr_key = map.keys[key_offset];
|
||||
|
||||
// update number of key hits
|
||||
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
|
||||
if (map.key_only) {
|
||||
// simple mode:
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
// We work with value values[0] only.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
||||
return;
|
||||
}
|
||||
|
||||
if (curr_key.key_num < map.min_hits) {
|
||||
// not enough hits to consider this a good draft
|
||||
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
|
||||
key_offset, curr_key.key_num, map.min_hits);
|
||||
return;
|
||||
}
|
||||
|
||||
// complex mode: examine the different m-grams after this key n-gram.
|
||||
//
|
||||
|
||||
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
|
||||
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
|
||||
// begins the key n-gram at index i?
|
||||
bool match_key = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[i + k] != key_tokens[k]) {
|
||||
match_key = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do we haven a existing value m-gram or a new one after the key at index i?
|
||||
size_t idx_begin_value_key = i + n;
|
||||
int idx_value = -1;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
size_t idx_begin_value_v = curr_key.values[v].value_idx;
|
||||
if (idx_begin_value_v == 0) {
|
||||
// We found an empty value slot => we found a new value m-gram after the key n-gram.
|
||||
curr_key.values[v].value_idx = idx_begin_value_key;
|
||||
curr_key.values[v].value_num = 0;
|
||||
curr_key.values[v].n_accepted = m;
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < m; ++j) {
|
||||
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
// We found an existing value m-gram after the key n-gram.
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx_value >= 0) {
|
||||
// We found a value m-gram of the key n-gram.
|
||||
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
}
|
||||
}
|
||||
// the statistics are updated up to match_pos.
|
||||
curr_key.stat_idx = match_pos;
|
||||
|
||||
// Do we have a value we could use for the draft?
|
||||
uint16_t max_occur = 0;
|
||||
int slot_max = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
if (curr_occur > max_occur) {
|
||||
max_occur = curr_occur;
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
continue;
|
||||
}
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
sum_occur += curr_occur;
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
key_offset,
|
||||
max_occur, sum_occur, slot_max,
|
||||
curr_key.values[0].value_idx, curr_key.values[0].value_num,
|
||||
curr_key.values[1].value_idx, curr_key.values[1].value_num,
|
||||
curr_key.values[2].value_idx, curr_key.values[2].value_num,
|
||||
curr_key.values[3].value_idx, curr_key.values[3].value_num
|
||||
);
|
||||
// Print the tokens of the four values (if idx != 0), use LOG_INF
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (curr_key.values[v].value_idx != 0) {
|
||||
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
|
||||
// The most frequent value is not much more frequent than the other values.
|
||||
// We do not use the draft.
|
||||
return;
|
||||
}
|
||||
|
||||
// We use the most frequent value values[slot_max] for the draft.
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = true;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = slot_max; // value used for draft generation.
|
||||
}
|
||||
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
||||
if (!map.last_draft_created) {
|
||||
return;
|
||||
}
|
||||
|
||||
// find the key and its chosen value.
|
||||
const size_t key_idx = map.last_draft_key_idx;
|
||||
const size_t val_idx = map.last_draft_value_idx;
|
||||
|
||||
// find key corresponding to key_idx.
|
||||
common_ngram_map_key & curr_key = map.keys[key_idx];
|
||||
// find value corresponding to val_idx.
|
||||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||
|
||||
// update the value statistics
|
||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
n_accepted, curr_value.n_accepted);
|
||||
curr_value.n_accepted = n_accepted;
|
||||
}
|
||||
|
||||
//
|
||||
// n-gram mod
|
||||
//
|
||||
|
||||
common_ngram_mod::common_ngram_mod(uint16_t m) : m(m) {
|
||||
int64_t n = 1;
|
||||
for (int32_t i = 0; i < N_MODS; ++i) {
|
||||
n *= mods[i];
|
||||
}
|
||||
|
||||
entries.resize(n);
|
||||
|
||||
const size_t size_bytes = entries.size() * sizeof(common_ngram_mod_entry);
|
||||
|
||||
LOG_INF("%s: size = %.3f MB\n", __func__, size_bytes / (1024.0 * 1024.0));
|
||||
}
|
||||
|
||||
void common_ngram_mod::add(const llama_token * tokens) {
|
||||
const uint64_t i = idx(tokens);
|
||||
|
||||
common_ngram_mod_entry & entry = entries[i];
|
||||
|
||||
if (entry.n_choices < COMMON_NGRAM_MOD_MAX_CHOICES) {
|
||||
entry.n_choices++;
|
||||
}
|
||||
|
||||
entry.choices[entry.head] = tokens[N_MODS];
|
||||
entry.head = (entry.head + 1) % COMMON_NGRAM_MOD_MAX_CHOICES;
|
||||
}
|
||||
|
||||
llama_token common_ngram_mod::get(const llama_token * tokens, int32_t offs) const {
|
||||
const uint64_t i = idx(tokens);
|
||||
|
||||
const common_ngram_mod_entry & entry = entries[i];
|
||||
|
||||
if (entry.n_choices == 0) {
|
||||
return LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
const int32_t k = (offs + entry.head) % entry.n_choices;
|
||||
|
||||
return entry.choices[k];
|
||||
}
|
||||
|
||||
uint64_t common_ngram_mod::idx(const llama_token * tokens) {
|
||||
uint64_t rh = 0;
|
||||
uint64_t res = 0;
|
||||
for (uint64_t i = 0; i < N_MODS; ++i) {
|
||||
rh = rh * 31 + tokens[i];
|
||||
res = res * mods[i] + (rh % mods[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void common_ngram_mod_draft(
|
||||
common_ngram_mod & mod,
|
||||
const llama_tokens & inp,
|
||||
llama_token sampled,
|
||||
llama_tokens & draft) {
|
||||
const size_t N_MODS = common_ngram_mod::N_MODS;
|
||||
|
||||
const size_t cur_len = inp.size();
|
||||
if (cur_len < N_MODS) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (mod.n_calls++ % 64 == 0) {
|
||||
const size_t n_start = (256*(mod.n_calls/64)) % GGML_PAD(cur_len, 256);
|
||||
for (size_t i = 0; i < 256 && n_start + i < cur_len - N_MODS; ++i) {
|
||||
mod.add(inp.data() + n_start + i);
|
||||
}
|
||||
}
|
||||
|
||||
draft.resize(N_MODS + mod.m);
|
||||
for (size_t i = 0; i < N_MODS - 1; ++i) {
|
||||
draft[i] = inp[cur_len - N_MODS + 1 + i];
|
||||
}
|
||||
draft[N_MODS - 1] = sampled;
|
||||
|
||||
for (size_t i = 0; i < mod.m; ++i) {
|
||||
const llama_token token = mod.get(draft.data() + i, cur_len + i);
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
draft.clear();
|
||||
return;
|
||||
}
|
||||
draft[N_MODS + i] = token;
|
||||
}
|
||||
|
||||
// only return the m tokens that were drafted
|
||||
for (size_t i = 0; i < mod.m; ++i) {
|
||||
draft[i] = draft[N_MODS + i];
|
||||
}
|
||||
draft.resize(mod.m);
|
||||
}
|
||||
143
common/ngram-map.h
Normal file
143
common/ngram-map.h
Normal file
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
//
|
||||
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
|
||||
//
|
||||
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
|
||||
//
|
||||
// There are two algorithms implemented:
|
||||
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
|
||||
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
|
||||
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
|
||||
//
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
// config of n-gram simple.
|
||||
struct common_ngram_simple_config {
|
||||
uint16_t size_ngram; // size of n-grams to lookup in self-mode
|
||||
uint16_t size_mgram; // size of m-grams to draft in self-mode
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
};
|
||||
|
||||
// current state (and config) of n-gram simple.
|
||||
struct common_ngram_simple_state {
|
||||
common_ngram_simple_config config;
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history (mutable)
|
||||
|
||||
common_ngram_simple_state(const common_ngram_simple_config & config)
|
||||
: config(config) {}
|
||||
};
|
||||
|
||||
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// state: the ngram simple state to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled);
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of m-gram values stored for each key n-gram.
|
||||
#define COMMON_NGRAM_MAX_VALUES 4
|
||||
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
// statistics of a n-gram
|
||||
struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
// map from n-grams to following m-grams in token-history
|
||||
struct common_ngram_map {
|
||||
uint16_t size_key; // size of key n-grams
|
||||
uint16_t size_value; // size of value m-grams
|
||||
|
||||
bool key_only; // true if only key n-grams are used, no values.
|
||||
|
||||
// first draft: vector only, no map.
|
||||
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
uint16_t min_hits; // minimum number of key hits to consider a draft
|
||||
|
||||
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
|
||||
uint16_t check_rate, uint16_t min_hits)
|
||||
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
|
||||
check_rate(check_rate), min_hits(min_hits) {}
|
||||
|
||||
bool last_draft_created = false; // true if a draft was created at last call.
|
||||
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
|
||||
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history
|
||||
};
|
||||
|
||||
|
||||
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// map: the ngram map to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
void common_ngram_map_draft(
|
||||
common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft);
|
||||
|
||||
// Update the statistics of a value after a draft was processed.
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
|
||||
|
||||
//
|
||||
// n-gram mod
|
||||
//
|
||||
|
||||
#define COMMON_NGRAM_MOD_MAX_CHOICES 4
|
||||
|
||||
struct common_ngram_mod_entry {
|
||||
uint32_t head = 0;
|
||||
uint32_t n_choices = 0;
|
||||
|
||||
llama_token choices[COMMON_NGRAM_MOD_MAX_CHOICES];
|
||||
};
|
||||
|
||||
struct common_ngram_mod {
|
||||
common_ngram_mod(uint16_t m);
|
||||
|
||||
void add(const llama_token * tokens);
|
||||
llama_token get(const llama_token * tokens, int32_t offs) const;
|
||||
|
||||
uint64_t n_calls = 0;
|
||||
|
||||
uint16_t m;
|
||||
|
||||
std::vector<common_ngram_mod_entry> entries;
|
||||
|
||||
static constexpr int32_t N_MODS = 17;
|
||||
static constexpr int32_t mods[N_MODS] = { 2, 1, 1, 1, 8, 1, 1, 1, 16, 1, 1, 1, 32, 1, 1, 1, 64, };
|
||||
|
||||
static uint64_t idx(const llama_token * tokens);
|
||||
};
|
||||
|
||||
void common_ngram_mod_draft(
|
||||
common_ngram_mod & mod,
|
||||
const llama_tokens & inp,
|
||||
llama_token sampled,
|
||||
llama_tokens & draft);
|
||||
@@ -1,97 +1,54 @@
|
||||
#include "speculative.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "ngram-cache.h"
|
||||
#include "ngram-map.h"
|
||||
#include "sampling.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <map>
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
struct common_speculative {
|
||||
struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
struct llama_context * ctx_dft;
|
||||
struct common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt_dft;
|
||||
bool vocab_dft_compatible = true; // whether retokenization is needed
|
||||
std::map<std::string, std::string> tgt_dft_replacements = {};
|
||||
const std::vector<enum common_speculative_type> common_speculative_types = {
|
||||
COMMON_SPECULATIVE_TYPE_NONE,
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT,
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft) {
|
||||
auto * result = new common_speculative {
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt_dft = */ {},
|
||||
/* .vocab_dft_compatible = */ false,
|
||||
};
|
||||
const std::map<std::string, enum common_speculative_type> common_speculative_type_from_name_map = {
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
|
||||
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
|
||||
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||
{"ngram_map_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD},
|
||||
{"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
|
||||
};
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
#if 0
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
struct common_speculative_config {
|
||||
common_speculative_type type;
|
||||
common_params_speculative params;
|
||||
|
||||
params.top_k = 40;
|
||||
params.top_p = 0.9;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
COMMON_SAMPLER_TYPE_TOP_P,
|
||||
COMMON_SAMPLER_TYPE_INFILL,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#else
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
|
||||
params.top_k = 10;
|
||||
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
#endif
|
||||
|
||||
result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
|
||||
LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
common_sampler_free(spec->smpl);
|
||||
|
||||
llama_batch_free(spec->batch);
|
||||
|
||||
delete spec;
|
||||
}
|
||||
|
||||
bool common_speculative_are_compatible(
|
||||
const struct llama_context * ctx_tgt,
|
||||
const struct llama_context * ctx_dft) {
|
||||
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||
const struct llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
common_speculative_config(common_speculative_type t,
|
||||
const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
|
||||
};
|
||||
|
||||
static bool common_speculative_are_compatible(
|
||||
const struct llama_model * model_tgt,
|
||||
const struct llama_model * model_dft) {
|
||||
const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
@@ -134,11 +91,12 @@ bool common_speculative_are_compatible(
|
||||
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
||||
common_token_to_piece(ctx_tgt, i).c_str(),
|
||||
common_token_to_piece(ctx_dft, i).c_str());
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -147,50 +105,437 @@ bool common_speculative_are_compatible(
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_speculative_add_replacement_tgt_dft(
|
||||
struct common_speculative * spec,
|
||||
const char *source, const char *dest) {
|
||||
spec->tgt_dft_replacements[source] = dest;
|
||||
// state of an implementation of speculative decoding
|
||||
//
|
||||
// each implementation has a unique type and a state that is implementation-specific
|
||||
// in a subclass of common_speculative_state
|
||||
struct common_speculative_state {
|
||||
const enum common_speculative_type type;
|
||||
|
||||
size_t drafts_call_count = 0; // number of times this implementation was called.
|
||||
size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation.
|
||||
size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model.
|
||||
size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation.
|
||||
size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model.
|
||||
|
||||
// TODO: track performance of most recent calls
|
||||
const bool gen_perf = true; // whether to generate performance stats.
|
||||
|
||||
int64_t gen_duration_us = 0; // total time spent in this implementation in microseconds.
|
||||
|
||||
virtual ~common_speculative_state() = default;
|
||||
|
||||
common_speculative_state(enum common_speculative_type type) : type(type) {}
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
struct llama_context * ctx_dft;
|
||||
|
||||
struct common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt_dft;
|
||||
|
||||
bool vocab_cmpt = true; // whether retokenization is needed
|
||||
std::unordered_map<std::string, std::string> vocab_map;
|
||||
|
||||
common_speculative_state_draft(
|
||||
enum common_speculative_type type,
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft,
|
||||
const std::vector<std::pair<std::string, std::string>> & replacements)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
{
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
smpl = nullptr;
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
// {
|
||||
// common_params_sampling params;
|
||||
// params.no_perf = false;
|
||||
//
|
||||
// params.top_k = 40;
|
||||
// params.top_p = 0.9;
|
||||
//
|
||||
// params.samplers = {
|
||||
// COMMON_SAMPLER_TYPE_TOP_K,
|
||||
// COMMON_SAMPLER_TYPE_TOP_P,
|
||||
// COMMON_SAMPLER_TYPE_INFILL,
|
||||
// };
|
||||
//
|
||||
// result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
// }
|
||||
{
|
||||
common_params_sampling params;
|
||||
params.no_perf = false;
|
||||
params.top_k = 10;
|
||||
params.samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
};
|
||||
|
||||
smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
||||
}
|
||||
|
||||
vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
|
||||
LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt);
|
||||
|
||||
if (!vocab_cmpt) {
|
||||
LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n");
|
||||
|
||||
for (const auto & pair : replacements) {
|
||||
vocab_map[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~common_speculative_state_draft() override {
|
||||
llama_perf_context_print(ctx_dft);
|
||||
|
||||
llama_free(ctx_dft);
|
||||
|
||||
common_sampler_free(smpl);
|
||||
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_eagle3 : public common_speculative_state {
|
||||
common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {}
|
||||
};
|
||||
|
||||
// state of self-speculation (simple implementation, not ngram-map)
|
||||
struct common_speculative_state_ngram_simple : public common_speculative_state {
|
||||
|
||||
common_ngram_simple_state state;
|
||||
|
||||
common_speculative_state_ngram_simple(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_simple_state state)
|
||||
: common_speculative_state(type), state(state) {}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_map_k : public common_speculative_state {
|
||||
// draft ngram map for speculative decoding without draft model
|
||||
common_ngram_map map;
|
||||
|
||||
common_speculative_state_ngram_map_k(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_map map)
|
||||
: common_speculative_state(type), map(std::move(map)) {}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_map_k4v : public common_speculative_state_ngram_map_k {
|
||||
common_speculative_state_ngram_map_k4v(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_map map)
|
||||
: common_speculative_state_ngram_map_k(type, std::move(map)) {}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_mod : public common_speculative_state {
|
||||
common_ngram_mod mod;
|
||||
|
||||
common_speculative_state_ngram_mod(
|
||||
enum common_speculative_type type,
|
||||
common_ngram_mod mod)
|
||||
: common_speculative_state(type), mod(std::move(mod)) {}
|
||||
};
|
||||
|
||||
struct common_speculative_state_ngram_cache : public common_speculative_state {
|
||||
uint16_t n_draft;
|
||||
bool save_dynamic;
|
||||
bool save_static;
|
||||
|
||||
common_ngram_cache ngram_cache_context;
|
||||
common_ngram_cache ngram_cache_dynamic;
|
||||
common_ngram_cache ngram_cache_static;
|
||||
|
||||
size_t cache_size = 0; // number of tokens in n-gram cache
|
||||
|
||||
common_speculative_state_ngram_cache(
|
||||
const enum common_speculative_type type,
|
||||
const std::string & path_static,
|
||||
const std::string & path_dynamic,
|
||||
uint16_t n_draft,
|
||||
bool save_dynamic,
|
||||
bool save_static)
|
||||
: common_speculative_state(type)
|
||||
, n_draft(n_draft)
|
||||
, save_dynamic(save_dynamic)
|
||||
, save_static(save_static)
|
||||
{
|
||||
if (!path_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(path_static);
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
GGML_ABORT("Couldn't read static lookup cache");
|
||||
}
|
||||
}
|
||||
|
||||
if (!path_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(path_dynamic);
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
GGML_ABORT("Couldn't read dynamic lookup cache");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative {
|
||||
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
|
||||
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
|
||||
};
|
||||
|
||||
static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
|
||||
uint16_t size_key = config.params.ngram_size_n;
|
||||
uint16_t size_value = config.params.ngram_size_m;
|
||||
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
||||
uint16_t check_rate = config.params.ngram_check_rate;
|
||||
uint16_t min_hits = config.params.ngram_min_hits;
|
||||
|
||||
return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits);
|
||||
}
|
||||
|
||||
static struct common_speculative_state_ngram_cache create_state_ngram_cache(
|
||||
const std::string & path_static, const std::string & path_dynamic,
|
||||
const common_speculative_config & config) {
|
||||
uint16_t n_draft = 8; // TODO get from config?
|
||||
|
||||
// TODO bool param in common/common.h to set save_static/save_dynamic?
|
||||
bool save_static = false;
|
||||
bool save_dynamic = false;
|
||||
|
||||
common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic);
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
std::string common_speculative_type_name_str() {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < common_speculative_types.size(); i++) {
|
||||
if (i > 0) {
|
||||
result += ", ";
|
||||
}
|
||||
result += common_speculative_type_to_str(common_speculative_types[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD: return "ngram_map_mod";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name) {
|
||||
const auto it = common_speculative_type_from_name_map.find(name);
|
||||
if (it == common_speculative_type_from_name_map.end()) {
|
||||
return COMMON_SPECULATIVE_TYPE_COUNT;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// initialization of the speculative decoding system
|
||||
//
|
||||
struct common_speculative * common_speculative_init(
|
||||
const struct common_params_speculative & params,
|
||||
struct llama_context * ctx_tgt,
|
||||
const struct llama_context_params & cparams_dft,
|
||||
struct llama_model * model_dft) {
|
||||
llama_context * ctx_dft = nullptr;
|
||||
if (model_dft) {
|
||||
ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
||||
if (ctx_dft == nullptr) {
|
||||
LOG_ERR("%s", "failed to create draft context\n");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the implementations to use based on the config and their order of preference
|
||||
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
|
||||
{
|
||||
bool has_draft = !params.model.path.empty();
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
|
||||
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
|
||||
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
|
||||
bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
|
||||
bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
|
||||
bool has_ngram_map_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD);
|
||||
|
||||
// In a more complex implementation we could use the same implementation but with different parameters.
|
||||
// This was initially used in PR-18471 but removed to simplify the code.
|
||||
if (has_ngram_simple) {
|
||||
// This implementation can guess a lot of tokens without any draft model.
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params));
|
||||
}
|
||||
if (has_ngram_map_k) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params));
|
||||
}
|
||||
if (has_ngram_map_k4v) {
|
||||
// This implementation can guess tokens with high acceptance rate but is more expensive.
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
|
||||
}
|
||||
if (has_ngram_map_mod) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD, params));
|
||||
}
|
||||
if (has_ngram_cache) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
|
||||
}
|
||||
if (has_draft) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
|
||||
}
|
||||
if (has_draft_eagle3) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<common_speculative_state>> implementations = {};
|
||||
|
||||
for (const common_speculative_config & config : configs) {
|
||||
LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str());
|
||||
switch (config.type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE:
|
||||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
||||
implementations.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
|
||||
implementations.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config);
|
||||
|
||||
uint16_t ngram_size_key = ngram_map.size_key;
|
||||
uint16_t mgram_size_value = ngram_map.size_value;
|
||||
uint16_t check_rate = ngram_map.check_rate;
|
||||
|
||||
auto config_simple = common_ngram_simple_config{
|
||||
/* .size_ngram = */ ngram_size_key,
|
||||
/* .size_mgram = */ mgram_size_value,
|
||||
/* .check_rate = */ check_rate
|
||||
};
|
||||
auto state = std::make_unique<common_speculative_state_ngram_simple>(
|
||||
/* .type = */ config.type,
|
||||
/* .state = */ common_ngram_simple_state(config_simple)
|
||||
);
|
||||
implementations.push_back(std::move(state));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: {
|
||||
implementations.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
|
||||
(config.type),
|
||||
get_common_ngram_map(config)
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
|
||||
implementations.push_back(std::make_unique<common_speculative_state_ngram_map_k4v>(
|
||||
(config.type),
|
||||
get_common_ngram_map(config)
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD: {
|
||||
common_ngram_mod mod(config.params.ngram_size_m);
|
||||
implementations.push_back(std::make_unique<common_speculative_state_ngram_mod>(
|
||||
(config.type),
|
||||
std::move(mod)
|
||||
));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
|
||||
auto state = create_state_ngram_cache(
|
||||
params.lookup_cache_static, params.lookup_cache_dynamic, config);
|
||||
implementations.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (implementations.empty()) {
|
||||
LOG_WRN("%s", "no implementations specified for speculative decoding\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto * result = new common_speculative {
|
||||
/* .impls = */ std::move(implementations)
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
delete spec;
|
||||
}
|
||||
|
||||
static std::string replace_to_dft(
|
||||
struct common_speculative * spec,
|
||||
const std::string& input) {
|
||||
struct common_speculative_state_draft * spec,
|
||||
const std::string & input) {
|
||||
std::string result = input;
|
||||
for (const auto & pair : spec->tgt_dft_replacements) {
|
||||
|
||||
for (const auto & pair : spec->vocab_map) {
|
||||
size_t pos = result.find(pair.first);
|
||||
while (pos != std::string::npos) {
|
||||
result.replace(pos, pair.first.length(), pair.second);
|
||||
pos = result.find(pair.first, pos + pair.second.length());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string replace_to_tgt(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_state_draft * spec,
|
||||
const std::string& input) {
|
||||
std::string result = input;
|
||||
for (const auto& pair : spec->tgt_dft_replacements) {
|
||||
|
||||
for (const auto & pair : spec->vocab_map) {
|
||||
size_t pos = result.find(pair.second);
|
||||
while (pos != std::string::npos) {
|
||||
result.replace(pos, pair.second.length(), pair.first);
|
||||
pos = result.find(pair.second, pos + pair.first.length());
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
static llama_tokens common_speculative_use_draft_model(
|
||||
struct common_speculative_state_draft * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
|
||||
const llama_tokens & prompt_tgt, // specified in target model vocab
|
||||
llama_token id_last) {
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx_tgt = spec->ctx_tgt;
|
||||
auto & ctx_dft = spec->ctx_dft;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx_tgt = spec->ctx_tgt;
|
||||
auto & ctx_dft = spec->ctx_dft;
|
||||
auto & smpl = spec->smpl;
|
||||
auto & prompt_dft = spec->prompt_dft;
|
||||
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
@@ -200,13 +545,16 @@ llama_tokens common_speculative_gen_draft(
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
|
||||
|
||||
llama_tokens prompt_tgt_draft_model;
|
||||
if (!spec->vocab_dft_compatible) {
|
||||
llama_tokens prompt_cnv;
|
||||
if (!spec->vocab_cmpt) {
|
||||
std::string text;
|
||||
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
|
||||
|
||||
text = common_detokenize(ctx_tgt, prompt_tgt, true);
|
||||
text = replace_to_dft(spec, text);
|
||||
|
||||
LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
|
||||
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
|
||||
|
||||
prompt_cnv = common_tokenize(ctx_dft, text, false, true);
|
||||
|
||||
// convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
|
||||
const auto * model_tgt = llama_get_model(ctx_tgt);
|
||||
@@ -214,6 +562,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
|
||||
int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
|
||||
GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
|
||||
|
||||
text.resize(-n_chars);
|
||||
llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
|
||||
text = replace_to_dft(spec, text);
|
||||
@@ -221,23 +570,22 @@ llama_tokens common_speculative_gen_draft(
|
||||
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
|
||||
id_last = common_tokenize(ctx_dft, text, false, true)[0];
|
||||
}
|
||||
// prompt_tgt's tokens will always be compatible with ctx_dft
|
||||
const llama_tokens &prompt_tgt =
|
||||
spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
|
||||
const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv;
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_cur.size() - n_ctx);
|
||||
|
||||
// reuse as much as possible from the old draft context
|
||||
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
||||
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
|
||||
int cur = 0;
|
||||
while (i_start + cur < (int) prompt_tgt.size() &&
|
||||
while (i_start + cur < (int) prompt_cur.size() &&
|
||||
i + cur < (int) prompt_dft.size() &&
|
||||
prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
|
||||
prompt_cur[i_start + cur] == prompt_dft[i + cur]) {
|
||||
cur++;
|
||||
}
|
||||
|
||||
if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
|
||||
if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) {
|
||||
reuse_i = i;
|
||||
reuse_n = cur;
|
||||
}
|
||||
@@ -282,11 +630,11 @@ llama_tokens common_speculative_gen_draft(
|
||||
// prepare a batch to evaluate any new tokens in the prompt
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
|
||||
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
|
||||
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
|
||||
for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) {
|
||||
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]);
|
||||
common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false);
|
||||
|
||||
prompt_dft.push_back(prompt_tgt[i]);
|
||||
prompt_dft.push_back(prompt_cur[i]);
|
||||
}
|
||||
|
||||
// we should rarely end-up here during normal decoding
|
||||
@@ -348,7 +696,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
prompt_dft.push_back(id);
|
||||
}
|
||||
|
||||
if (!spec->vocab_dft_compatible) {
|
||||
if (!spec->vocab_cmpt) {
|
||||
std::string detokenized = common_detokenize(ctx_dft, result, true);
|
||||
detokenized = replace_to_tgt(spec, detokenized);
|
||||
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
|
||||
@@ -357,5 +705,211 @@ llama_tokens common_speculative_gen_draft(
|
||||
result.resize(params.n_draft);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform speculative generation using a 3-tier n-gram cache.
|
||||
*
|
||||
* @param state Current state of this implementation
|
||||
* @param tokens Token history to search in
|
||||
* @param sampled Last sampled token
|
||||
* @return Vector of draft tokens, empty if draft is found
|
||||
*/
|
||||
static llama_tokens common_speculative_gen_ngram_cache(
|
||||
common_speculative_state_ngram_cache & state,
|
||||
const llama_tokens & tokens, llama_token sampled) {
|
||||
if (state.cache_size < tokens.size() + 1) {
|
||||
llama_tokens tokens_new;
|
||||
tokens_new.reserve(tokens.size() + 1 - state.cache_size);
|
||||
for (size_t j = state.cache_size; j < tokens.size(); ++j) {
|
||||
tokens_new.push_back(tokens[j]);
|
||||
}
|
||||
tokens_new.push_back(sampled); // add the last token
|
||||
|
||||
// Update context ngram cache with new tokens:
|
||||
common_ngram_cache_update(state.ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
|
||||
tokens_new, tokens_new.size(), false);
|
||||
state.cache_size = tokens.size() + 1;
|
||||
}
|
||||
|
||||
llama_tokens inp;
|
||||
inp.reserve(tokens.size() + 1);
|
||||
for (size_t j = 0; j < tokens.size(); ++j) {
|
||||
inp.push_back(tokens[j]);
|
||||
}
|
||||
inp.push_back(sampled);
|
||||
|
||||
llama_tokens draft;
|
||||
draft.push_back(sampled);
|
||||
|
||||
common_ngram_cache_draft(inp, draft, state.n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
|
||||
state.ngram_cache_context,
|
||||
state.ngram_cache_dynamic,
|
||||
state.ngram_cache_static);
|
||||
|
||||
if (draft.size() > 0) {
|
||||
// delete first token in draft (which is the sampled token)
|
||||
draft.erase(draft.begin());
|
||||
}
|
||||
|
||||
return draft;
|
||||
}
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt_tgt, // specified in target model vocab
|
||||
llama_token id_last) {
|
||||
llama_tokens result = {};
|
||||
|
||||
spec->curr_impl = nullptr; // reset current implementation
|
||||
|
||||
// TODO: avoid dynamic casts
|
||||
for (auto & impl : spec->impls) {
|
||||
impl->drafts_call_count++;
|
||||
const int64_t t_start_us = impl->gen_perf ? ggml_time_us() : 0;
|
||||
|
||||
switch (impl->type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE:
|
||||
{
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT:
|
||||
{
|
||||
// Create a draft using a draft model.
|
||||
auto * draft_impl = dynamic_cast<struct common_speculative_state_draft *>(impl.get());
|
||||
if (draft_impl) {
|
||||
result = common_speculative_use_draft_model(draft_impl, params, prompt_tgt, id_last);
|
||||
} else {
|
||||
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
|
||||
}
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_EAGLE3:
|
||||
{
|
||||
// Work in progress: https://github.com/ggml-org/llama.cpp/pull/18039
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
|
||||
{
|
||||
// Use common_ngram_map_draft to generate a draft from the current context.
|
||||
auto * state = dynamic_cast<struct common_speculative_state_ngram_simple *>(impl.get());
|
||||
if (state) {
|
||||
result = common_ngram_simple_draft(state->state, prompt_tgt, id_last);
|
||||
} else {
|
||||
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
|
||||
}
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
|
||||
{
|
||||
// Use common_ngram_map_draft to generate a draft from the current context.
|
||||
auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
|
||||
if (state) {
|
||||
common_ngram_map_draft(state->map, prompt_tgt, id_last, result);
|
||||
} else {
|
||||
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
|
||||
}
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
|
||||
{
|
||||
// Use common_ngram_map_draft to generate a draft from the current context.
|
||||
auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
|
||||
if (state) {
|
||||
common_ngram_map_draft(state->map, prompt_tgt, id_last, result);
|
||||
} else {
|
||||
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
|
||||
}
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_MOD:
|
||||
{
|
||||
auto * state = dynamic_cast<common_speculative_state_ngram_mod *>(impl.get());
|
||||
if (state) {
|
||||
common_ngram_mod_draft(state->mod, prompt_tgt, id_last, result);
|
||||
} else {
|
||||
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
|
||||
}
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
|
||||
{
|
||||
auto * state = dynamic_cast<common_speculative_state_ngram_cache *>(impl.get());
|
||||
if (state) {
|
||||
result = common_speculative_gen_ngram_cache(*state, prompt_tgt, id_last);
|
||||
} else {
|
||||
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
|
||||
}
|
||||
} break;
|
||||
case COMMON_SPECULATIVE_TYPE_COUNT:
|
||||
{
|
||||
GGML_ABORT("invalid speculative type COUNT");
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_now_us = impl->gen_perf ? ggml_time_us() : 0;
|
||||
impl->gen_duration_us += t_now_us - t_start_us; // accumulate duration for this implementation
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(),
|
||||
prompt_tgt.size(),
|
||||
impl.get()->drafts_call_count, result.size());
|
||||
spec->curr_impl = impl.get(); // set current implementation for stats
|
||||
impl->drafts_generated_count++;
|
||||
impl->drafts_generated_tokens += result.size();
|
||||
|
||||
break; // We have a draft, so break out of the loop and return it.
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_speculative_accept(struct common_speculative * spec, uint16_t n_accepted) {
|
||||
if (n_accepted == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
common_speculative_state * impl = spec->curr_impl;
|
||||
|
||||
GGML_ASSERT(impl);
|
||||
|
||||
if (n_accepted > 0) {
|
||||
impl->drafts_accepted_count++;
|
||||
impl->drafts_accepted_tokens += n_accepted;
|
||||
}
|
||||
|
||||
LOG_WRN("XXXXXXXXXXXXX n_accepted = %d\n", n_accepted);
|
||||
|
||||
if (impl->type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K ||
|
||||
impl->type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V) {
|
||||
|
||||
// TODO: add common_speculative_state::accept() to base class and remove this dynamic cast
|
||||
auto * state = dynamic_cast<struct common_speculative_state_ngram_map_k *>(impl);
|
||||
if (state) {
|
||||
common_ngram_map_accept(state->map, n_accepted);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const struct common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & impl : spec->impls) {
|
||||
std::string str_perf;
|
||||
if (impl->gen_perf) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(3) << impl->gen_duration_us / 1000.0;
|
||||
str_perf = ", dur = " + oss.str() + " ms";
|
||||
} else {
|
||||
str_perf = "";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->drafts_call_count,
|
||||
impl->drafts_generated_count,
|
||||
impl->drafts_accepted_count,
|
||||
impl->drafts_generated_tokens,
|
||||
impl->drafts_accepted_tokens,
|
||||
str_perf.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,29 +7,36 @@ struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
|
||||
// comma separated list of all types
|
||||
std::string common_speculative_type_name_str();
|
||||
|
||||
// convert string to type
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
||||
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
);
|
||||
const struct common_params_speculative & params,
|
||||
struct llama_context * ctx_tgt,
|
||||
const struct llama_context_params & cparams_dft,
|
||||
struct llama_model * model_dft);
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
|
||||
bool common_speculative_are_compatible(
|
||||
const struct llama_context * ctx_tgt,
|
||||
const struct llama_context * ctx_dft);
|
||||
|
||||
void common_speculative_add_replacement_tgt_dft(
|
||||
struct common_speculative * spec,
|
||||
const char *source, const char *dest);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(struct common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const struct common_speculative * spec);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -170,6 +170,7 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [Linux](#Linux)
|
||||
- [Known Issue](#known-issues)
|
||||
- [TODO](#todo)
|
||||
|
||||
|
||||
@@ -248,6 +248,14 @@ You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda
|
||||
CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf
|
||||
```
|
||||
|
||||
#### CUDA_SCALE_LAUNCH_QUEUES
|
||||
|
||||
The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU.
|
||||
|
||||
**Default behavior:** llama.cpp automatically sets `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs.
|
||||
|
||||
See PR [#19042](https://github.com/ggml-org/llama.cpp/pull/19042) for performance benchmarks and technical details.
|
||||
|
||||
### Unified Memory
|
||||
|
||||
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
|
||||
|
||||
120
docs/speculative.md
Normal file
120
docs/speculative.md
Normal file
@@ -0,0 +1,120 @@
|
||||
# Speculative Decoding
|
||||
|
||||
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
|
||||
|
||||
[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
|
||||
|
||||
## Implementations
|
||||
|
||||
The `llama-server` application supports several implementations of speculative decoding:
|
||||
|
||||
### Draft Model (`draft`)
|
||||
|
||||
A much smaller model (called the _draft model_) generates drafts.
|
||||
A draft model is the most used approach in speculative decoding.
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
|
||||
|
||||
See:
|
||||
|
||||
- #5479, #6828, #6848
|
||||
|
||||
### n-gram Map (`ngram-simple`, `ngram-map-*`)
|
||||
|
||||
These implementations search the token history for patterns and use matching sequences as draft candidates.
|
||||
They require no additional model but rely on patterns that have already appeared in the generated text.
|
||||
An example to use this approach can be the rewriting of source code by a LLM.
|
||||
|
||||
#### n-gram Map (`ngram-simple`)
|
||||
|
||||
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
|
||||
|
||||
#### n-gram Map Key (`ngram-map-k`)
|
||||
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
|
||||
|
||||
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:** Server options to be used if there are a lot of longer repetitions.
|
||||
```bash
|
||||
llama-server [...] --spec-draftless ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
|
||||
```
|
||||
|
||||
|
||||
## Command-Line Options (draftless)
|
||||
|
||||
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
|
||||
|
||||
```
|
||||
--spec-draftless [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
(default: none)
|
||||
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
|
||||
(default: 1)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### `--spec-draftless TYPE`
|
||||
|
||||
Specifies a type of speculative decoding without draft model.
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
|
||||
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
|
||||
|
||||
**Example:** Server-instance used to refactor source code.
|
||||
```bash
|
||||
./llama-server [...] --spec-draftless ngram-simple
|
||||
```
|
||||
|
||||
### `--spec-ngram-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
|
||||
|
||||
### `--spec-ngram-size-m M`
|
||||
|
||||
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-check-rate R`
|
||||
|
||||
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
|
||||
## Statistics
|
||||
Each speculative decoding implementation prints statistics.
|
||||
|
||||
```
|
||||
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
|
||||
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
|
||||
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
|
||||
```
|
||||
|
||||
- `#calls`: number of calls of this implementations
|
||||
- `#gen drafts`: number of drafts generated by this implementation
|
||||
- `#acc drafts`: number of drafts accepted (partially) by the main model
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
|
||||
@@ -32,9 +32,9 @@ int main(int argc, char ** argv){
|
||||
|
||||
common_ngram_cache ngram_cache;
|
||||
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
|
||||
|
||||
common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
|
||||
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -46,18 +46,18 @@ int main(int argc, char ** argv){
|
||||
{
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
|
||||
@@ -51,18 +51,18 @@ int main(int argc, char ** argv){
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
@@ -210,7 +210,7 @@ int main(int argc, char ** argv){
|
||||
|
||||
// Update dynamic ngram cache with context ngram cache and save it to disk:
|
||||
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ set -e
|
||||
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -13,6 +14,10 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-debug -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
|
||||
cmake --build ${BUILD_DIR} --target llama-debug -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
|
||||
|
||||
@@ -5,11 +5,16 @@ set -e
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
|
||||
BUILD_DIR="${3:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$MODEL_TESTING_PROMPT"]; then
|
||||
if [ -z "$MODEL_TESTING_PROMPT" ]; then
|
||||
MODEL_TESTING_PROMPT="Hello, my name is"
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
@@ -21,6 +26,6 @@ fi
|
||||
echo $CONVERTED_MODEL
|
||||
echo $MODEL_TESTING_PROMPT
|
||||
|
||||
cmake --build ../../build --target llama-debug -j8
|
||||
cmake --build ${BUILD_DIR} --target llama-debug -j8
|
||||
|
||||
../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
|
||||
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
|
||||
|
||||
@@ -28,6 +28,7 @@ done
|
||||
|
||||
# First try command line argument, then environment variable
|
||||
CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}"
|
||||
BUILD_DIR="${BUILD_DIR:-"../../build"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -50,5 +51,5 @@ fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-debug -j8
|
||||
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
|
||||
cmake --build ${BUILD_DIR} --target llama-debug -j8
|
||||
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
set -e
|
||||
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -25,9 +26,13 @@ mkdir -p ppl
|
||||
OUTPUTFILE="ppl/$(basename $CONVERTED_MODEL).kld"
|
||||
echo "Model: $CONVERTED_MODEL"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
-f ppl/wikitext-2-raw/wiki.test.raw \
|
||||
--kl-divergence-base $OUTPUTFILE
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
@@ -20,8 +21,12 @@ if [ ! -d "ppl/wikitext-2-raw" ]; then
|
||||
popd
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
LOGITS_FILE="${1:-"$LOGITS_FILE"}"
|
||||
LOGITS_FILE="${2:-"$LOGITS_FILE"}"
|
||||
BUILD_DIR="${3:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
@@ -18,11 +19,15 @@ if [ ! -f ${LOGITS_FILE} ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo "Model: $QUANTIZED_MODEL"
|
||||
echo "Data file: $LOGITS_FILE"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
--kl-divergence-base $LOGITS_FILE \
|
||||
--kl-divergence
|
||||
|
||||
@@ -6,6 +6,7 @@ CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
|
||||
TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
|
||||
OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
|
||||
BUILD_DIR="${5:-"$BUILD_DIR"}"
|
||||
QUANTIZED_MODEL=$CONVERTED_MODEL
|
||||
|
||||
# Final check if we have a model path
|
||||
@@ -33,12 +34,16 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-quantize -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
cmake --build $BUILD_DIR --target llama-quantize -j8
|
||||
|
||||
echo $TOKEN_EMBD_TYPE
|
||||
echo $OUTPUT_TYPE
|
||||
|
||||
CMD_ARGS=("../../build/bin/llama-quantize")
|
||||
CMD_ARGS=("${BUILD_DIR}/bin/llama-quantize")
|
||||
[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
|
||||
[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
|
||||
CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
|
||||
|
||||
@@ -4,6 +4,7 @@ set -e
|
||||
#
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -13,10 +14,14 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-server
|
||||
cmake --build $BUILD_DIR --target llama-server
|
||||
|
||||
../../build/bin/llama-server -m $CONVERTED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-server -m $CONVERTED_MODEL \
|
||||
--embedding \
|
||||
--pooling none
|
||||
|
||||
@@ -34,10 +34,9 @@ int main(int argc, char ** argv) {
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model_tgt = NULL;
|
||||
//llama_model * model_dft = NULL;
|
||||
llama_model * model_dft = NULL;
|
||||
|
||||
llama_context * ctx_tgt = NULL;
|
||||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
auto llama_init_tgt = common_init_from_params(params);
|
||||
@@ -63,12 +62,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
|
||||
//model_dft = llama_init_dft->model();
|
||||
ctx_dft = llama_init_dft->context();
|
||||
|
||||
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
|
||||
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
|
||||
}
|
||||
model_dft = llama_init_dft->model();
|
||||
|
||||
// Tokenize the prompt
|
||||
std::vector<llama_token> inp;
|
||||
@@ -129,13 +123,9 @@ int main(int argc, char ** argv) {
|
||||
// init the speculator
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft;
|
||||
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
|
||||
params_spec.p_min = p_min;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
|
||||
for (auto &pair : params.speculative.replacements) {
|
||||
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
|
||||
}
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt, common_context_params_to_llama(params), model_dft);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
@@ -249,8 +239,6 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("\n");
|
||||
LOG_INF("draft:\n\n");
|
||||
|
||||
llama_perf_context_print(ctx_dft);
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("target:\n\n");
|
||||
common_perf_print(ctx_tgt, smpl);
|
||||
|
||||
@@ -630,10 +630,11 @@ extern "C" {
|
||||
|
||||
// this tensor...
|
||||
enum ggml_tensor_flag {
|
||||
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
|
||||
};
|
||||
|
||||
enum ggml_tri_type {
|
||||
@@ -2577,11 +2578,42 @@ extern "C" {
|
||||
struct ggml_tensor * grad,
|
||||
struct ggml_tensor * sgd_params); // alpha, weight decay
|
||||
|
||||
// build forward mutiple tensors and select one of them for computing
|
||||
// this is useful for creating graphs that have constant topology but compute different things based on the input
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
||||
//
|
||||
// automatic differentiation
|
||||
// nodes:
|
||||
// | - build forward into the graph but do not compute
|
||||
// c - build forward into the graph and compute
|
||||
//
|
||||
// | | ... c ... |
|
||||
// | | ... c ... |
|
||||
// | | ... c ... |
|
||||
// [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx)
|
||||
// c
|
||||
// c
|
||||
//
|
||||
// example:
|
||||
// struct ggml_tensor * curs[3];
|
||||
//
|
||||
// curs[0] = compute0(...);
|
||||
// curs[1] = compute1(...);
|
||||
// curs[2] = compute2(...);
|
||||
//
|
||||
// int idx = select_branch(some_input);
|
||||
//
|
||||
// struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
|
||||
//
|
||||
GGML_API struct ggml_tensor * ggml_build_forward_select(
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor ** tensors,
|
||||
int n_tensors,
|
||||
int idx);
|
||||
|
||||
GGML_API void ggml_build_forward_expand(
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(
|
||||
struct ggml_context * ctx, // context for gradient computation
|
||||
struct ggml_cgraph * cgraph,
|
||||
@@ -2613,7 +2645,7 @@ extern "C" {
|
||||
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
|
||||
|
||||
// dump the graph into a file using the dot format
|
||||
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
|
||||
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
|
||||
|
||||
// TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
|
||||
typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
@@ -77,39 +77,23 @@
|
||||
#include "ggml-zendnn.h"
|
||||
#endif
|
||||
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static std::string path_str(const fs::path & path) {
|
||||
std::string u8path;
|
||||
try {
|
||||
#if defined(__cpp_lib_char8_t)
|
||||
// C++20 and later: u8string() returns std::u8string
|
||||
std::u8string u8str = path.u8string();
|
||||
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
|
||||
const std::u8string u8str = path.u8string();
|
||||
return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
|
||||
#else
|
||||
// C++17: u8string() returns std::string
|
||||
u8path = path.u8string();
|
||||
return path.u8string();
|
||||
#endif
|
||||
} catch (...) {
|
||||
return std::string();
|
||||
}
|
||||
return u8path;
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
@@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
|
||||
}
|
||||
if (sched->debug > 1) {
|
||||
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
|
||||
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
|
||||
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
|
||||
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
|
||||
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
|
||||
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
@@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
|
||||
dst->view_offs = src->view_offs;
|
||||
}
|
||||
dst->op = src->op;
|
||||
dst->flags = src->flags;
|
||||
memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
|
||||
ggml_set_name(dst, src->name);
|
||||
|
||||
|
||||
@@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
ggml_backend_blas_mul_mat(ctx, node);
|
||||
|
||||
@@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
||||
if (!ok) {
|
||||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
|
||||
@@ -38,9 +38,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -48,9 +49,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -70,12 +72,14 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
@@ -94,9 +98,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -104,9 +109,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -126,9 +132,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -136,9 +143,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -165,18 +173,20 @@
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -202,9 +212,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -212,9 +223,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -242,9 +254,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -252,9 +265,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
||||
@@ -25,9 +25,8 @@
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
|
||||
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
|
||||
int16x8_t * out_mins,
|
||||
int8_t * out_scales) {
|
||||
// Helper for decoding scales and mins of Q4_K and Q5_K block formats
|
||||
static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
@@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
@@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
@@ -786,6 +785,293 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[ncols_interleaved / 4];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
|
||||
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
|
||||
float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
|
||||
|
||||
// 2 sb each iteration
|
||||
int32x4_t acc_lo[col_pairs];
|
||||
int32x4_t acc_hi[col_pairs];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
|
||||
// Load qh once per block and shift after each subblock
|
||||
const uint8_t * qh_base = q5_ptr[b].qh;
|
||||
uint8x16_t qh[col_pairs][4];
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
||||
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
||||
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
||||
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_pairs; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
int16x8_t q5sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
|
||||
|
||||
// Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
||||
int8x16_t q8_qs[8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
||||
}
|
||||
|
||||
// Q5s column pair loop unrolled
|
||||
{
|
||||
// Cols 01
|
||||
uint8x16_t qs_0 = vld1q_u8(qs_base);
|
||||
uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
|
||||
uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
|
||||
uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
|
||||
|
||||
uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
|
||||
uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
|
||||
uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
|
||||
uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
|
||||
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
|
||||
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
|
||||
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
|
||||
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
|
||||
|
||||
qh[0][0] = vshrq_n_u8(qh[0][0], 2);
|
||||
qh[0][1] = vshrq_n_u8(qh[0][1], 2);
|
||||
qh[0][2] = vshrq_n_u8(qh[0][2], 2);
|
||||
qh[0][3] = vshrq_n_u8(qh[0][3], 2);
|
||||
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 23
|
||||
qs_0 = vld1q_u8(qs_base + 16);
|
||||
qs_1 = vld1q_u8(qs_base + 80);
|
||||
qs_2 = vld1q_u8(qs_base + 144);
|
||||
qs_3 = vld1q_u8(qs_base + 208);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[1][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[1][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[1][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[1][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
|
||||
|
||||
qh[1][0] = vshrq_n_u8(qh[1][0], 2);
|
||||
qh[1][1] = vshrq_n_u8(qh[1][1], 2);
|
||||
qh[1][2] = vshrq_n_u8(qh[1][2], 2);
|
||||
qh[1][3] = vshrq_n_u8(qh[1][3], 2);
|
||||
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 45
|
||||
qs_0 = vld1q_u8(qs_base + 32);
|
||||
qs_1 = vld1q_u8(qs_base + 96);
|
||||
qs_2 = vld1q_u8(qs_base + 160);
|
||||
qs_3 = vld1q_u8(qs_base + 224);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[2][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[2][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[2][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[2][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
|
||||
|
||||
qh[2][0] = vshrq_n_u8(qh[2][0], 2);
|
||||
qh[2][1] = vshrq_n_u8(qh[2][1], 2);
|
||||
qh[2][2] = vshrq_n_u8(qh[2][2], 2);
|
||||
qh[2][3] = vshrq_n_u8(qh[2][3], 2);
|
||||
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 45
|
||||
qs_0 = vld1q_u8(qs_base + 48);
|
||||
qs_1 = vld1q_u8(qs_base + 112);
|
||||
qs_2 = vld1q_u8(qs_base + 176);
|
||||
qs_3 = vld1q_u8(qs_base + 240);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[3][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[3][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[3][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[3][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
|
||||
|
||||
qh[3][0] = vshrq_n_u8(qh[3][0], 2);
|
||||
qh[3][1] = vshrq_n_u8(qh[3][1], 2);
|
||||
qh[3][2] = vshrq_n_u8(qh[3][2], 2);
|
||||
qh[3][3] = vshrq_n_u8(qh[3][3], 2);
|
||||
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
}
|
||||
|
||||
// Prepare bsum vectors for bias computation
|
||||
// Each pair of subblocks share the same bsums
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
||||
// p = 0 -> 0123 p2 -> 4567
|
||||
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
||||
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
|
||||
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
|
||||
int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
|
||||
int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
|
||||
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
||||
float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
|
||||
|
||||
// 0123 or 4567
|
||||
float32x4_t sumf_0 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
||||
|
||||
float32x4_t sumf_1 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
||||
|
||||
// FUSED BIAS: Compute and subtract bias immediately
|
||||
// bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
|
||||
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
|
||||
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
|
||||
float32x4_t bias_f32 = vcvtq_f32_s32(bias);
|
||||
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
|
||||
}
|
||||
} // for sb
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -2431,7 +2717,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
@@ -2595,7 +2881,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
@@ -2738,6 +3024,252 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
float32x4_t acc_f32[blocklen];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[4][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
||||
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
||||
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
||||
for (int i = 0; i < 8; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
// Load qh once per block and shift after each subblock
|
||||
const uint8_t * qh_base = q5_ptr[b].qh;
|
||||
uint8x16_t qh[col_pairs][4];
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
||||
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
||||
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
||||
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int8_t q5sb_scales[2][8];
|
||||
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
int8x16_t q8_qs_01[8];
|
||||
int8x16_t q8_qs_23[8];
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
||||
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
||||
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
||||
}
|
||||
|
||||
const int8x16_t q8s[2][8] = {
|
||||
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
|
||||
q8_qs_01[7] },
|
||||
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
|
||||
q8_qs_23[7] },
|
||||
};
|
||||
|
||||
// Q5s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sb_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
||||
uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
||||
uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
||||
uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
||||
|
||||
// This is the only part of the algorithm that differs with Q4_K
|
||||
// Extract High bits and pack into 5 bit weights
|
||||
uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
|
||||
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
|
||||
qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
|
||||
// Same as Q4_K, i8mm to dequantize the weights.
|
||||
const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
|
||||
int32x4_t acc_0 = sb_acc[0];
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
|
||||
int32x4_t acc_2 = sb_acc[2];
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
|
||||
const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
|
||||
int32x4_t acc_1 = sb_acc[1];
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
|
||||
int32x4_t acc_3 = sb_acc[3];
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
|
||||
|
||||
// Repeat for the other 3 columns (8..15, 16..23, 24..31)
|
||||
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
|
||||
uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
|
||||
qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
|
||||
const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
|
||||
const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
|
||||
|
||||
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
|
||||
uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
|
||||
qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
|
||||
const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
|
||||
const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
|
||||
|
||||
uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
|
||||
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
|
||||
qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
|
||||
const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
|
||||
sb_acc[0] = acc_0;
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
|
||||
sb_acc[2] = acc_2;
|
||||
|
||||
// Scales[i] corresponds to column i
|
||||
const int scale_offset = cp * 2;
|
||||
const int32_t s0 = q5sb_scales[0][scale_offset];
|
||||
const int32_t s1 = q5sb_scales[0][scale_offset + 1];
|
||||
const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
|
||||
|
||||
const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
|
||||
sb_acc[1] = acc_1;
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
|
||||
sb_acc[3] = acc_3;
|
||||
|
||||
const int32_t s2 = q5sb_scales[1][scale_offset];
|
||||
const int32_t s3 = q5sb_scales[1][scale_offset + 1];
|
||||
const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
||||
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
// Reorder of i8mm output with bias and output layout
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
||||
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
||||
}
|
||||
int32x4_t reorder_acc[8] = {
|
||||
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
||||
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
||||
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
||||
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
||||
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
||||
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
||||
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
||||
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
||||
};
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
||||
float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
|
||||
const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
|
||||
|
||||
float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
|
||||
const float32x4_t scale = vmulq_f32(q5_d, q8_d);
|
||||
|
||||
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
||||
acc_f32[2 * i + j] =
|
||||
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
||||
}
|
||||
}
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_FA_TILE_Q 32
|
||||
#define GGML_FA_TILE_KV 16
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include <utility>
|
||||
@@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa
|
||||
return {ir0, ir1};
|
||||
}
|
||||
|
||||
struct ggml_fa_tile_config {
|
||||
static constexpr size_t Q = GGML_FA_TILE_Q;
|
||||
static constexpr size_t KV = GGML_FA_TILE_KV;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "vec.h"
|
||||
#include "ops.h"
|
||||
#include "ggml.h"
|
||||
#include "common.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
@@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
||||
const int64_t ne20 = node->src[2]->ne[0]; // DV
|
||||
const int64_t DK = node->src[1]->ne[0];
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
{
|
||||
@@ -2943,6 +2946,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback &&
|
||||
|
||||
@@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX {
|
||||
} \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
struct mma_instr;
|
||||
|
||||
template<>
|
||||
struct mma_instr<ggml_bf16_t> {
|
||||
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
||||
__builtin_mma_xvbf16ger2pp(acc, a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct mma_instr<ggml_fp16_t> {
|
||||
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
||||
__builtin_mma_xvf16ger2pp(acc, a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TA, typename TB, typename TC>
|
||||
class tinyBLAS_BF16_PPC {
|
||||
class tinyBLAS_HP16_PPC {
|
||||
public:
|
||||
tinyBLAS_BF16_PPC(int64_t k,
|
||||
tinyBLAS_HP16_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const TB *B, int64_t ldb,
|
||||
TC *C, int64_t ldc,
|
||||
@@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC {
|
||||
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
@@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC {
|
||||
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
@@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC {
|
||||
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
|
||||
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC {
|
||||
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
|
||||
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x<2; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
@@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC {
|
||||
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
|
||||
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
|
||||
for (int x = 0; x<4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
@@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
return tb.matmul(m, n);
|
||||
}
|
||||
#elif defined(__MMA__)
|
||||
if ((k % 8))
|
||||
return false;
|
||||
if(Btype == GGML_TYPE_BF16) {
|
||||
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
||||
(const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
if (k % 8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Btype == GGML_TYPE_BF16) {
|
||||
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
||||
(const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth };
|
||||
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
}
|
||||
#elif defined(__riscv_zvfbfwma)
|
||||
#if LMUL == 1
|
||||
@@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
#endif
|
||||
return tb.matmul(m, n);
|
||||
}
|
||||
#elif defined(__MMA__)
|
||||
if (k % 8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Btype == GGML_TYPE_F16) {
|
||||
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
|
||||
(const ggml_fp16_t *)A, lda,
|
||||
(const ggml_fp16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth };
|
||||
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
// online softmax / attention
|
||||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
||||
if (mv == -INFINITY) {
|
||||
@@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
int ir0, int ir1) {
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int64_t DK = nek0;
|
||||
const int64_t DV = nev0;
|
||||
const int64_t N = neq1;
|
||||
|
||||
GGML_ASSERT(ne0 == DV);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
// input tensor rows must be contiguous
|
||||
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
||||
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||
|
||||
GGML_ASSERT(neq0 == DK);
|
||||
GGML_ASSERT(nek0 == DK);
|
||||
GGML_ASSERT(nev0 == DV);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
const ggml_type kv_type = k->type;
|
||||
|
||||
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
|
||||
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
|
||||
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
|
||||
const size_t kv_type_size = ggml_type_size(kv_type);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t rk2 = neq2/nek2;
|
||||
const int64_t rk3 = neq3/nek3;
|
||||
|
||||
const int64_t rv2 = neq2/nev2;
|
||||
const int64_t rv3 = neq3/nev3;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
int ith = params->ith;
|
||||
|
||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
|
||||
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
|
||||
|
||||
int ir = ir0;
|
||||
while (ir < ir1) {
|
||||
// q indices for the start of this tile
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
|
||||
// Number of valid rows in this tile:
|
||||
// - limited by tile size (Q_TILE_SZ)
|
||||
// - limited by chunk boundary (ir1 - ir)
|
||||
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
|
||||
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
|
||||
GGML_ASSERT(tile_rows > 0);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
float S[Q_TILE_SZ];
|
||||
float M[Q_TILE_SZ];
|
||||
|
||||
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
|
||||
S[i] = 0.;
|
||||
M[i] = -INFINITY;
|
||||
}
|
||||
|
||||
// Per-thread scratch layout:
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
|
||||
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
||||
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
||||
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
|
||||
|
||||
void * Q_q = base;
|
||||
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
||||
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
|
||||
|
||||
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
||||
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
|
||||
}
|
||||
// Zero-pad remaining rows
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
|
||||
}
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
|
||||
// skip the tile entirely if all the masks are -inf
|
||||
if (mask) {
|
||||
bool can_skip = true;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
||||
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
||||
can_skip = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (can_skip) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
float s;
|
||||
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
|
||||
KQ[tq * KV_TILE_SZ + tk] = s * scale;
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
|
||||
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
|
||||
}
|
||||
|
||||
bool skip[Q_TILE_SZ] = {};
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
float * kq_row = KQ + tq * KV_TILE_SZ;
|
||||
|
||||
float tile_max;
|
||||
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
|
||||
|
||||
if (tile_max == -INFINITY) {
|
||||
skip[tq] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float Mold = M[tq];
|
||||
const float Mnew = fmaxf(Mold, tile_max);
|
||||
|
||||
if (Mnew > Mold) {
|
||||
const float ms = expf(Mold - Mnew);
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
||||
S[tq] *= ms;
|
||||
}
|
||||
M[tq] = Mnew;
|
||||
|
||||
|
||||
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
||||
}
|
||||
|
||||
// Convert V tile to F32 first (if F16), then do MAD
|
||||
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
|
||||
// TODO: on ARM, native f16 should be faster
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sinks (apply only to valid rows in the tile)
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M[tq]) {
|
||||
ms = expf(M[tq] - s);
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M[tq]);
|
||||
}
|
||||
|
||||
S[tq] = S[tq] * ms + vs;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
// V /= S
|
||||
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
|
||||
|
||||
// dst indices
|
||||
const int i1 = iq1 + tq;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
|
||||
}
|
||||
|
||||
ir += tile_rows;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
// The number of elements in each chunk
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
||||
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
|
||||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
@@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const int64_t ir0 = dr * current_chunk;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
if (use_tiled) {
|
||||
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
||||
} else {
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
}
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
|
||||
@@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
@@ -616,6 +609,100 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
@@ -1212,6 +1299,108 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -1622,7 +1811,95 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
|
||||
out.scales[i] = in[src1].scales[src2];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
|
||||
block_q5_Kx8 out;
|
||||
//Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
||||
}
|
||||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
for (int i = 0; i < end / 4; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
// The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.
|
||||
// Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
|
||||
// The output Q5_Kx8 structure has 96 bytes
|
||||
// Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure
|
||||
// For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures
|
||||
uint8_t s[8], m[8];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = in[j].scales[i] & 63;
|
||||
m[j] = in[j].scales[i + 4] & 63;
|
||||
}
|
||||
|
||||
out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
|
||||
m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
|
||||
}
|
||||
|
||||
out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
@@ -1718,6 +1995,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
int interleave_block,
|
||||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
const block_q5_K * src = (const block_q5_K *) data;
|
||||
block_q5_K dst_tmp[8];
|
||||
int nrow = ggml_nrows(t);
|
||||
int nblocks = t->ne[0] / QK_K;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
@@ -1936,6 +2245,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
@@ -1973,6 +2286,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
||||
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -1981,8 +2298,8 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
@@ -2013,20 +2330,24 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
||||
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
@@ -2432,6 +2753,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
|
||||
|
||||
@@ -2482,6 +2806,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q2_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q5_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -44,6 +44,7 @@ struct block_q4_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
@@ -52,6 +53,18 @@ struct block_q2_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants
|
||||
uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
|
||||
"wrong q5_K block size/padding");
|
||||
|
||||
struct block_q8_Kx4 {
|
||||
float d[4]; // delta
|
||||
int8_t qs[QK_K * 4]; // quants
|
||||
@@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
|
||||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
|
||||
# define STRIDED_ITERATOR_AVAILABLE
|
||||
# endif
|
||||
using namespace cub;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
@@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef STRIDED_ITERATOR_AVAILABLE
|
||||
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx <= nrows) {
|
||||
offsets[idx] = idx * ncols;
|
||||
}
|
||||
}
|
||||
#endif // STRIDED_ITERATOR_AVAILABLE
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
@@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
||||
int * temp_indices = temp_indices_alloc.get();
|
||||
float * temp_keys = temp_keys_alloc.get();
|
||||
int * d_offsets = offsets_alloc.get();
|
||||
|
||||
static const int block_size = 256;
|
||||
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
#ifdef STRIDED_ITERATOR_AVAILABLE
|
||||
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
|
||||
#else
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
int * offset_iterator = offsets_alloc.get();
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
|
||||
#endif
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
@@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
d_offsets, d_offsets + 1, stream);
|
||||
offset_iterator, offset_iterator + 1, stream);
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
@@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
@@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||
stream);
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1123,6 +1123,7 @@ struct ggml_tensor_extra_gpu {
|
||||
struct ggml_cuda_graph_node_properties {
|
||||
void * node_address;
|
||||
ggml_op node_op;
|
||||
int32_t flags;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
void * src_address[GGML_MAX_SRC];
|
||||
@@ -1326,10 +1327,44 @@ struct ggml_backend_cuda_context {
|
||||
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
||||
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Map from first_node_ptr to cuda_graph - allows multiple graphs per context
|
||||
// when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
|
||||
std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
|
||||
|
||||
ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
|
||||
auto it = cuda_graphs.find(first_node_ptr);
|
||||
if (it == cuda_graphs.end()) {
|
||||
cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
|
||||
return cuda_graphs[first_node_ptr].get();
|
||||
}
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Check if any CUDA graph is enabled for this context (used by kernels that need to know
|
||||
// if graphs are in use without having access to the specific graph key)
|
||||
bool any_cuda_graph_enabled() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->is_enabled()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if any CUDA graph has an instance for this context
|
||||
bool any_cuda_graph_has_instance() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->instance != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
|
||||
@@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
|
||||
const int nbatch_fa) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
|
||||
const int ne11, const int ne12, const int nbatch_fa) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -641,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
||||
|
||||
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
@@ -654,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
return;
|
||||
}
|
||||
|
||||
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
||||
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
|
||||
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
||||
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
@@ -681,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
@@ -778,13 +785,11 @@ void launch_fattn(
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
@@ -794,9 +799,9 @@ void launch_fattn(
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT(K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
|
||||
@@ -817,10 +822,10 @@ void launch_fattn(
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
@@ -849,36 +854,45 @@ void launch_fattn(
|
||||
K_data = (char *) K_f16.ptr;
|
||||
}
|
||||
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (V_is_K_view) {
|
||||
V_data = K_data;
|
||||
nb21 = nb11;
|
||||
nb22 = nb12;
|
||||
nb23 = nb13;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
|
||||
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
|
||||
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
@@ -953,7 +967,7 @@ void launch_fattn(
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
||||
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
@@ -1007,7 +1021,7 @@ void launch_fattn(
|
||||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
|
||||
@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
||||
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
||||
bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
||||
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
@@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
||||
|
||||
const int k_VKQ_0 = kb0 * nbatch_fa;
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
@@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
|
||||
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
@@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over K in reverse and later re-use the data if possible.
|
||||
#pragma unroll
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||
for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
|
||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||
const int k0_diff = k0_stop - k0_start;
|
||||
|
||||
@@ -510,7 +511,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -522,14 +522,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
// Wide version of KQ_C is column-major
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -773,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
||||
// Preload K tile for next iteration:
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
@@ -788,10 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
T_A_VKQ A_identity;
|
||||
make_identity_mat(A_identity);
|
||||
@@ -799,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
|
||||
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
||||
#pragma unroll
|
||||
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
||||
static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
|
||||
const int i0_stop = i0_start + 2*nbatch_V2;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if constexpr (nstages <= 1) {
|
||||
if (i0_start < reusable_cutoff) {
|
||||
if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
||||
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
|
||||
@@ -814,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
@@ -917,7 +919,7 @@ template<int ncols> struct mma_tile_sizes {
|
||||
};
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -931,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float logit_softcap,
|
||||
const uint3 ne01,
|
||||
const int ne02,
|
||||
const int gqa_ratio,
|
||||
const int ne11,
|
||||
const int stride_Q1,
|
||||
const int stride_Q2,
|
||||
@@ -938,6 +941,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int stride_V,
|
||||
const int stride_mask,
|
||||
const int jt,
|
||||
const int zt_gqa,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
@@ -953,7 +957,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -971,8 +975,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||
|
||||
extern __shared__ half2 tile_Q[];
|
||||
@@ -1021,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int j = jc / ncols2;
|
||||
const int c = jc % ncols2;
|
||||
|
||||
if (jt*ncols1 + j < int(ne01.z)) {
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
@@ -1076,7 +1079,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = false;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1085,7 +1088,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = true;
|
||||
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1096,7 +1099,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = false;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1105,7 +1108,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = true;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1407,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int j_dst = jc_dst / ncols2;
|
||||
const int c_dst = jc_dst % ncols2;
|
||||
|
||||
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
|
||||
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1446,14 +1449,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
|
||||
scale, slope, logit_softcap, ne01, ne02,
|
||||
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
|
||||
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
||||
jt, kb0_start, kb0_stop);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if (ncols1*ncols2 < 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1498,8 +1508,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
||||
@@ -1512,14 +1520,15 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int stride_K = nb11 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half);
|
||||
|
||||
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
|
||||
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
@@ -1530,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
||||
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
||||
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
||||
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
||||
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
||||
|
||||
const int head0 = zt * ncols2;
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
||||
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
||||
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
if (KV_max) {
|
||||
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
||||
@@ -1553,14 +1564,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
||||
}
|
||||
|
||||
kbc += iter_k;
|
||||
@@ -1574,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
||||
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
|
||||
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
||||
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
||||
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
||||
|
||||
const int head0 = zt * ncols2;
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
||||
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
||||
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
if (KV_max) {
|
||||
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
||||
@@ -1597,9 +1610,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
@@ -1633,7 +1646,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
constexpr bool mla = DKQ == 576;
|
||||
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
||||
|
||||
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||
@@ -1658,7 +1671,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
||||
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1669,7 +1682,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
#endif // !defined(GGML_USE_MUSA)
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
||||
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1728,3 +1741,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
|
||||
@@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
}
|
||||
}
|
||||
|
||||
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
if constexpr (ncols2 <= 16) {
|
||||
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
|
||||
@@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
|
||||
template <int DKQ, int DV>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
@@ -46,7 +49,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
// are put into the template specialization without GQA optimizations.
|
||||
bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
@@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
// On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
|
||||
if (cc == GGML_CUDA_CC_VOLTA) {
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio > 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
if (use_gqa_opt && gqa_ratio > 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
if (use_gqa_opt && gqa_ratio > 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
@@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
@@ -121,8 +146,34 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
if (gqa_ratio == 20) { // GLM 4.7 Flash
|
||||
if (cc >= GGML_CUDA_CC_BLACKWELL) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
if (Q->ne[1] <= 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
if (cc >= GGML_CUDA_CC_TURING) {
|
||||
if (Q->ne[1] <= 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
break;
|
||||
}
|
||||
// Volta:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
} else if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -230,9 +281,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
|
||||
// The effective batch size for the kernel can be increased by gqa_ratio.
|
||||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
@@ -243,6 +294,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
}
|
||||
|
||||
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
switch (K->ne[0]) {
|
||||
@@ -262,7 +315,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!V_is_K_view) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -2918,6 +2918,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
|
||||
props->node_address = node->data;
|
||||
props->node_op = node->op;
|
||||
props->flags = node->flags;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
props->ne[i] = node->ne[i];
|
||||
props->nb[i] = node->nb[i];
|
||||
@@ -2961,21 +2962,32 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
|
||||
return cgraph->nodes[0];
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool res = false;
|
||||
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->instance == nullptr) {
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
res = true;
|
||||
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
@@ -2983,37 +2995,38 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
|
||||
}
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
bool props_match= true;
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
|
||||
}
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
|
||||
#else
|
||||
cudaGraphNode_t errorNode;
|
||||
cudaGraphExecUpdateResult result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
@@ -3024,14 +3037,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
(void)cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
|
||||
graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
const ggml_tensor * view,
|
||||
@@ -3236,7 +3249,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
// flag used to determine whether it is an integrated_gpu
|
||||
@@ -3378,6 +3391,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// start of fusion operations
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
@@ -3687,13 +3703,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
|
||||
if (cuda_ctx->cuda_graph->graph != nullptr) {
|
||||
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
|
||||
cuda_ctx->cuda_graph->graph = nullptr;
|
||||
if (graph->graph != nullptr) {
|
||||
CUDA_CHECK(cudaGraphDestroy(graph->graph));
|
||||
graph->graph = nullptr;
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
|
||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||
|
||||
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
|
||||
@@ -3706,40 +3723,39 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
if (cuda_graph_update_required) { // Update graph executable
|
||||
ggml_cuda_graph_update_executable(cuda_ctx);
|
||||
ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
|
||||
#else
|
||||
graph_evaluated_or_captured = true;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (cuda_ctx->cuda_graph == nullptr) {
|
||||
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
||||
}
|
||||
|
||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
||||
if (graph->graph == nullptr) {
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
||||
if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) {
|
||||
if (!graph->disable_due_to_gpu_arch) {
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
|
||||
}
|
||||
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
||||
graph->disable_due_to_gpu_arch = true;
|
||||
}
|
||||
}
|
||||
|
||||
return cuda_ctx->cuda_graph->is_enabled();
|
||||
return graph->is_enabled();
|
||||
#else
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(graph_key);
|
||||
return false;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
@@ -3751,15 +3767,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
|
||||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
const void * graph_key = nullptr;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
|
||||
if (cuda_ctx->cuda_graph->is_enabled()) {
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
|
||||
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
@@ -3773,7 +3793,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -3806,7 +3826,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
||||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
#else
|
||||
const bool use_cuda_graph = false;
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(cgraph);
|
||||
#endif
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
@@ -4849,6 +4876,16 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (!initialized) {
|
||||
// Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance
|
||||
// PR: https://github.com/ggml-org/llama.cpp/pull/19042
|
||||
if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) {
|
||||
#ifdef _WIN32
|
||||
_putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x");
|
||||
#else
|
||||
setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
|
||||
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
|
||||
|
||||
|
||||
@@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
if ((nrows == 1) &&
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// CUDA_GRAPHS_DISABLED
|
||||
((ncols > 65536) &&
|
||||
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->is_enabled())) ||
|
||||
// CUDA_GRAPHS ENABLED
|
||||
((ncols > 32768) &&
|
||||
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->is_enabled()))) {
|
||||
// Determine if CUDA graphs are effectively disabled for this context
|
||||
// (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
|
||||
(((ncols > 65536) &&
|
||||
(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
|
||||
ctx.any_cuda_graph_enabled())) ||
|
||||
// CUDA graphs are enabled - use lower threshold
|
||||
((ncols > 32768) &&
|
||||
!(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
|
||||
ctx.any_cuda_graph_enabled())))) {
|
||||
#else
|
||||
(ncols > 65536)) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -71,7 +71,7 @@ for type_k in TYPES_KV:
|
||||
f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
|
||||
|
||||
for ncols in [8, 16, 32, 64]:
|
||||
for ncols2 in [1, 2, 4, 8, 16]:
|
||||
for ncols2 in [1, 2, 4, 8, 16, 32]:
|
||||
if ncols2 > ncols:
|
||||
continue
|
||||
ncols1 = ncols // ncols2
|
||||
@@ -83,9 +83,9 @@ for ncols in [8, 16, 32, 64]:
|
||||
continue
|
||||
if head_size_kq == 72:
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
if head_size_kq != 576 and ncols2 in (16, 32):
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
|
||||
# include <cuda/iterator>
|
||||
# define CUB_TOP_K_AVAILABLE
|
||||
using namespace cub;
|
||||
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
|
||||
|
||||
@@ -2497,6 +2497,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t flags = 0;
|
||||
|
||||
// skip quantizer if src1 is reused
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <assert.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
@@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * v (float)
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
@@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
uint32_t ic = 0;
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
@@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
scores_x4.v[iv] = scores;
|
||||
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
|
||||
}
|
||||
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
v_max = hvx_vec_reduce_max_f32(v_max);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
|
||||
float ms = expf(M_old - M_new);
|
||||
|
||||
const float ms = expf(M_old - M_new);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
S = S * ms;
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = scores_x4.v[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
|
||||
float p_sum = hvx_vec_get_f32(p_sum_vec);
|
||||
S += p_sum;
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S = S * ms + hvx_vec_get_f32(p_sum_vec);
|
||||
}
|
||||
|
||||
// Leftover
|
||||
|
||||
@@ -611,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in
|
||||
if (node->op != ops[i]) {
|
||||
return false;
|
||||
}
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
return false;
|
||||
}
|
||||
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -785,8 +785,12 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||
dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
|
||||
|
||||
dev->props.max_buffer_size = dev->mtl_device.maxBufferLength;
|
||||
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
|
||||
dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
|
||||
if (@available(macOS 10.12, iOS 16.0, *)) {
|
||||
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
|
||||
} else {
|
||||
dev->props.max_working_set_size = dev->mtl_device.maxBufferLength;
|
||||
}
|
||||
|
||||
strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1);
|
||||
|
||||
@@ -1078,12 +1082,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
op->src[0]->ne[0] != 256 &&
|
||||
op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
||||
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ABORT("unsupported op");
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int n_fuse = 1;
|
||||
|
||||
// check if the current node can run concurrently with other nodes before it
|
||||
@@ -2516,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
int32_t nsg = 4;
|
||||
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
||||
@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
|
||||
|
||||
constexpr short NC = (C/8)/NSG;
|
||||
|
||||
// note: do not unroll for large heads
|
||||
#pragma unroll (DK <= 64 ? NC : 1)
|
||||
for (short cc = 0; cc < NC; ++cc) {
|
||||
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
||||
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
||||
|
||||
if (DK % 16 != 0) {
|
||||
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
|
||||
k8x8_t mk[2];
|
||||
q8x8_t mq[2];
|
||||
|
||||
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
|
||||
// note: too much unroll can tank the performance for large heads
|
||||
#pragma unroll (MIN(DK8/2, 4*NSG))
|
||||
for (short i = 0; i < DK8/2; ++i) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
||||
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
|
||||
pv += 8*NS20;
|
||||
}
|
||||
} else {
|
||||
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
|
||||
constexpr short NC = (C/8)/2;
|
||||
|
||||
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
||||
s8x8_t vs[2];
|
||||
|
||||
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
||||
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS
|
||||
add
|
||||
add_id
|
||||
argsort
|
||||
tri
|
||||
fill
|
||||
clamp
|
||||
cpy
|
||||
@@ -84,7 +85,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q4_0_f32_8x_flat
|
||||
mul_mv_q4_0_f32_1d_8x_flat
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q6_k
|
||||
mul_mv_q6_k_f32
|
||||
mul_mv_q6_k_f32_flat
|
||||
mul_mv_q8_0_f32
|
||||
mul_mv_q8_0_f32_flat
|
||||
mul_mv_mxfp4_f32
|
||||
|
||||
@@ -398,6 +398,7 @@ struct ggml_backend_opencl_context {
|
||||
int adreno_wave_size;
|
||||
|
||||
cl_bool non_uniform_workgroups;
|
||||
size_t image_max_buffer_size;
|
||||
|
||||
cl_context context;
|
||||
cl_command_queue queue;
|
||||
@@ -407,6 +408,10 @@ struct ggml_backend_opencl_context {
|
||||
ggml_cl_buffer prealloc_scales_trans;
|
||||
ggml_cl_buffer prealloc_act_trans;
|
||||
|
||||
// prealloc buffers for src0 and src1
|
||||
ggml_cl_buffer prealloc_src0;
|
||||
ggml_cl_buffer prealloc_src1;
|
||||
|
||||
cl_program program_add;
|
||||
cl_program program_add_id;
|
||||
cl_program program_clamp;
|
||||
@@ -489,6 +494,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
|
||||
cl_kernel kernel_relu;
|
||||
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
|
||||
cl_kernel kernel_tri;
|
||||
cl_kernel kernel_fill;
|
||||
cl_kernel kernel_clamp;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
|
||||
@@ -527,8 +533,10 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_convert_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32_flat;
|
||||
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
|
||||
cl_kernel kernel_solve_tri_f32;
|
||||
@@ -793,6 +801,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// tri
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "tri.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("tri.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
}
|
||||
|
||||
// fill
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -868,6 +894,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -1090,14 +1118,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q6_k
|
||||
// mul_mv_q6_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q6_k.cl.h"
|
||||
#include "mul_mv_q6_k_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q6_k.cl");
|
||||
const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl");
|
||||
#endif
|
||||
backend_ctx->program_mul_mv_q6_K =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
@@ -1106,6 +1134,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q6_k_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q6_k_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q8_0_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -2639,6 +2684,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
|
||||
|
||||
@@ -2892,6 +2940,50 @@ struct ggml_tensor_extra_cl_q8_0 {
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_q6_K {
|
||||
// Lower 4 bits of quantized weights.
|
||||
cl_mem ql = nullptr;
|
||||
// Upper 2 bits of quantized weights.
|
||||
cl_mem qh = nullptr;
|
||||
// Scales for each block.
|
||||
cl_mem s = nullptr;
|
||||
// Scales for each super block.
|
||||
cl_mem d = nullptr;
|
||||
|
||||
size_t size_ql = 0;
|
||||
size_t size_qh = 0;
|
||||
size_t size_s = 0;
|
||||
size_t size_d = 0;
|
||||
|
||||
~ggml_tensor_extra_cl_q6_K() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (ql != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(ql));
|
||||
ql = nullptr;
|
||||
}
|
||||
if (qh != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(qh));
|
||||
qh = nullptr;
|
||||
}
|
||||
if (s != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(s));
|
||||
s = nullptr;
|
||||
}
|
||||
if (d != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(d));
|
||||
d = nullptr;
|
||||
}
|
||||
|
||||
size_ql = 0;
|
||||
size_qh = 0;
|
||||
size_s = 0;
|
||||
size_d = 0;
|
||||
}
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Backend API
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -3058,6 +3150,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
|
||||
ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
@@ -3201,6 +3297,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_TRI:
|
||||
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
|
||||
case GGML_OP_FILL:
|
||||
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
|
||||
case GGML_OP_CLAMP:
|
||||
@@ -3432,6 +3530,12 @@ struct ggml_backend_opencl_buffer_context {
|
||||
for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
|
||||
delete e;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
|
||||
@@ -3494,6 +3598,21 @@ struct ggml_backend_opencl_buffer_context {
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {
|
||||
ggml_tensor_extra_cl_q6_K * extra;
|
||||
if (temp_tensor_extras_q6_K.empty()) {
|
||||
extra = new ggml_tensor_extra_cl_q6_K();
|
||||
} else {
|
||||
extra = temp_tensor_extras_q6_K.back();
|
||||
temp_tensor_extras_q6_K.pop_back();
|
||||
}
|
||||
|
||||
temp_tensor_extras_q6_K_in_use.push_back(extra);
|
||||
|
||||
extra->reset();
|
||||
return extra;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
|
||||
temp_tensor_extras.push_back(e);
|
||||
@@ -3514,6 +3633,11 @@ struct ggml_backend_opencl_buffer_context {
|
||||
temp_tensor_extras_q8_0.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_q8_0_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
|
||||
temp_tensor_extras_q6_K.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_q6_K_in_use.clear();
|
||||
}
|
||||
|
||||
// Pools for extras. Available extras are in `temp_tensor_extras`. Extras
|
||||
@@ -3529,6 +3653,8 @@ struct ggml_backend_opencl_buffer_context {
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;
|
||||
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;
|
||||
|
||||
// The buffer_context is initially created by ggml_backend_buft_alloc_buffer
|
||||
// before any tensor is initialized (at the beginning of alloc_tensor_range).
|
||||
@@ -4035,6 +4161,92 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q6_K) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
|
||||
// Allocate the new extra and create aliases from the original.
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K();
|
||||
|
||||
size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
|
||||
size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4;
|
||||
size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16;
|
||||
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) &&
|
||||
"Incorrect tensor size");
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
CL_CHECK(clEnqueueWriteBuffer(
|
||||
queue, data_device, CL_TRUE, 0,
|
||||
ggml_nbytes(tensor), data, 0, NULL, NULL));
|
||||
|
||||
cl_buffer_region region;
|
||||
|
||||
// Subbuffer for ql
|
||||
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
|
||||
region.size = size_ql;
|
||||
extra->ql = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
auto previous_origin = region.origin;
|
||||
|
||||
// Subbuffer for qh
|
||||
region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment);
|
||||
region.size = size_qh;
|
||||
extra->qh = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Subbuffer for scales
|
||||
region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment);
|
||||
region.size = size_s;
|
||||
extra->s = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for d.
|
||||
region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
|
||||
region.size = size_d;
|
||||
extra->d = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Flatten the weights
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
extra->size_ql = size_ql;
|
||||
extra->size_qh = size_qh;
|
||||
extra->size_s = size_s;
|
||||
extra->size_d = size_d;
|
||||
|
||||
tensor->extra = extra;
|
||||
return;
|
||||
}
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
|
||||
@@ -4244,6 +4456,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q6_K) {
|
||||
ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
@@ -4686,6 +4926,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg
|
||||
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
|
||||
}
|
||||
|
||||
// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but
|
||||
// nb[] is recalculated such that tensor is contiguous.
|
||||
static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst,
|
||||
cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
const int tensor_type_size = ggml_type_size(src->type);
|
||||
|
||||
const int ne00 = src->ne[0];
|
||||
const int ne01 = src->ne[1];
|
||||
const int ne02 = src->ne[2];
|
||||
const int ne03 = src->ne[3];
|
||||
|
||||
const cl_ulong nb00 = src->nb[0];
|
||||
const cl_ulong nb01 = src->nb[1];
|
||||
const cl_ulong nb02 = src->nb[2];
|
||||
const cl_ulong nb03 = src->nb[3];
|
||||
|
||||
const int ne0 = src->ne[0];
|
||||
const int ne1 = src->ne[1];
|
||||
const int ne2 = src->ne[2];
|
||||
const int ne3 = src->ne[3];
|
||||
|
||||
nb0 = tensor_type_size;
|
||||
nb1 = tensor_type_size*ne00;
|
||||
nb2 = tensor_type_size*ne00*ne01;
|
||||
nb3 = tensor_type_size*ne00*ne01*ne02;
|
||||
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
|
||||
|
||||
cl_ulong offset0 = extra->offset + src->view_offs;
|
||||
cl_ulong offsetd = 0;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
switch (src->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_cpy_f32_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_cpy_f16_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src);
|
||||
}
|
||||
|
||||
static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
UNUSED(backend);
|
||||
UNUSED(src0);
|
||||
@@ -5961,6 +6276,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
UNUSED(src1);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int tri_type = ggml_get_op_params_i32(dst, 0);
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_tri;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type));
|
||||
|
||||
size_t local_work_size[1] = { 256 };
|
||||
size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
@@ -7619,6 +7972,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
|
||||
#endif
|
||||
|
||||
const int ne00 = src0 ? src0->ne[0] : 0;
|
||||
@@ -7661,9 +8015,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
cl_context context = backend_ctx->context;
|
||||
|
||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 &&
|
||||
// dst is wrapped with image1d_buffer, the size limit applies, also src0
|
||||
(ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) {
|
||||
// For KQ
|
||||
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
|
||||
((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) &&
|
||||
nb00 <= nb02 &&
|
||||
nb02 <= nb01 &&
|
||||
nb01 <= nb03 &&
|
||||
@@ -7674,7 +8031,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
return;
|
||||
}
|
||||
// For KQV
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
|
||||
((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) {
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
@@ -7980,9 +8338,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
|
||||
// GEMM using local memory
|
||||
// Current BK = 16, so ne00 % 16 == 0
|
||||
if (ggml_is_contiguous(src0) &&
|
||||
ggml_is_contiguous(src1) &&
|
||||
src1t == GGML_TYPE_F32 &&
|
||||
if (src1t == GGML_TYPE_F32 &&
|
||||
ne00 % 16 == 0 &&
|
||||
ne11 > 1) {
|
||||
switch(src0t) {
|
||||
@@ -7994,10 +8350,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
cl_mem mem_src0 = extra0->data_device;
|
||||
cl_mem mem_src1 = extra1->data_device;
|
||||
|
||||
cl_ulong nb00_cont = nb00;
|
||||
cl_ulong nb01_cont = nb01;
|
||||
cl_ulong nb02_cont = nb02;
|
||||
cl_ulong nb03_cont = nb03;
|
||||
|
||||
cl_ulong nb10_cont = nb10;
|
||||
cl_ulong nb11_cont = nb11;
|
||||
cl_ulong nb12_cont = nb12;
|
||||
cl_ulong nb13_cont = nb13;
|
||||
|
||||
cl_ulong offset0_cont = offset0;
|
||||
cl_ulong offset1_cont = offset1;
|
||||
|
||||
if (!ggml_is_contiguous(src0)) {
|
||||
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
|
||||
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
|
||||
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
|
||||
mem_src0 = backend_ctx->prealloc_src0.buffer;
|
||||
offset0_cont = 0;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
|
||||
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
|
||||
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
|
||||
mem_src1 = backend_ctx->prealloc_src1.buffer;
|
||||
offset1_cont = 0;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
@@ -8029,10 +8417,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
cl_mem mem_src0 = extra0->data_device;
|
||||
cl_mem mem_src1 = extra1->data_device;
|
||||
|
||||
cl_ulong nb00_cont = nb00;
|
||||
cl_ulong nb01_cont = nb01;
|
||||
cl_ulong nb02_cont = nb02;
|
||||
cl_ulong nb03_cont = nb03;
|
||||
|
||||
cl_ulong nb10_cont = nb10;
|
||||
cl_ulong nb11_cont = nb11;
|
||||
cl_ulong nb12_cont = nb12;
|
||||
cl_ulong nb13_cont = nb13;
|
||||
|
||||
cl_ulong offset0_cont = offset0;
|
||||
cl_ulong offset1_cont = offset1;
|
||||
|
||||
if (!ggml_is_contiguous(src0)) {
|
||||
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
|
||||
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
|
||||
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
|
||||
mem_src0 = backend_ctx->prealloc_src0.buffer;
|
||||
offset0_cont = 0;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
|
||||
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
|
||||
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
|
||||
mem_src1 = backend_ctx->prealloc_src1.buffer;
|
||||
offset1_cont = 0;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
@@ -8060,6 +8480,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
@@ -8432,14 +8856,49 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 2;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 2;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3));
|
||||
#else
|
||||
kernel = backend_ctx->kernel_mul_mv_q6_K_f32;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 2;
|
||||
nth1 = 16;
|
||||
nth0 = 16;
|
||||
nth1 = 2;
|
||||
ndst = 1;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 2;
|
||||
nth1 = 64;
|
||||
nth0 = 64;
|
||||
nth1 = 2;
|
||||
ndst = 1;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
@@ -8459,6 +8918,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
case GGML_TYPE_MXFP4: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
@@ -8561,7 +9021,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
} else if (src0t == GGML_TYPE_Q5_K) {
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
} else if (src0t == GGML_TYPE_Q6_K) {
|
||||
size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
|
||||
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
@@ -10008,6 +10468,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
}
|
||||
func = ggml_cl_glu;
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cl_tri;
|
||||
break;
|
||||
case GGML_OP_FILL:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user