Compare commits

...

38 Commits

Author SHA1 Message Date
Pascal
4adac43f6f server: tests: fetch random media marker via /apply-template (#21962) (#21980)
* server: tests: fetch random media marker via /apply-template (#21962 fix)

* server: allow pinning media marker via LLAMA_MEDIA_MARKER env var

get_media_marker() checks LLAMA_MEDIA_MARKER at first call and uses it
as-is if set, falling back to the random marker otherwise.

Tests no longer need to fetch the marker dynamically via /apply-template:
the fixture sets LLAMA_MEDIA_MARKER=<__media__> so the hardcoded prompts
work as before.

Address review feedback from ngxson

* server: make get_media_marker() thread-safe via magic statics

Use a C++11 static local with a lambda initializer instead of a global
static with an empty-check. The runtime guarantees initialization exactly
once without explicit locking.

Address review feedback from ggerganov

* nits

* nits
2026-04-16 20:46:21 +03:00
PikaPikachu
9db77a020c model : refactor QKV into common build_qkv and create_tensor_qkv helpers (#21245)
* model : refactor QKV into common build_qkv and create_tensor_qkv helpers

* model : extend build_qkv to bert/mpt/dbrx/olmo/lfm2/nemotron-h/granite-hybrid/gemma3n-iswa/t5-dec and fix wqkv_s
2026-04-16 17:41:34 +02:00
Sigbjørn Skjæret
f772f6e434 model : support NVFP4 tensors for Gemma4 (#21971)
* support nvfp4 tensors for Gemma4

* add wo_s to build_attn

* add wo_s to build_attn

* fix glm4
2026-04-16 16:51:47 +02:00
Ruben Ortlam
b572d1ecd6 codeowners: add team member comments (#21714) 2026-04-16 13:13:11 +03:00
Anav Prasad
03b3d07798 Convert: Fix NemotronH Config Parsing (#21664)
* fix NemotronH vocab loading by using trust_remote_code for unsupported config patterns

* fix NemotronH tokenizer loading by overriding set_vocab with trust_remote_code
2026-04-16 13:11:45 +03:00
Aman Gupta
3f7c29d318 ggml: add graph_reused (#21764)
* ggml: add graph_reused

* use versioning instead of reuse flag

* increment version with atomic

* use top bits for split numbering

* add assert

* move counter to ggml.c

* set uid in split_graph only

* fix windows

* address further review comments

* get next_uid rather than doing bit manipulation

* rename + add comment about uid
2026-04-16 17:21:28 +08:00
Kusha Gharahi
ae2d34899e metal: Implement ROLL op (#21946)
* nix: support unified apple-sdk

* Impl roll op for Metal

* Revert "nix: support unified apple-sdk"

This reverts commit abfa473360.

* update ops.md

* update op docs
2026-04-16 11:54:37 +03:00
rehan-10xengineer
1e796eb41f ggml-cpu: add 128-bit RVV implementation for Quantization Vector Dot (#20633)
* ggml-cpu: add 128-bit impls for i-quants, ternary quants

* ggml-cpu: add 128-bit impls for iq2_xs, iq3_s, iq3_xxs, tq2_0

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: refactor; add rvv checks

---------

Co-authored-by: taimur-10x <taimur.ahmad@10xengineers.ai>
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
2026-04-16 11:15:15 +03:00
rehan-10xengineer
5637536517 ggml : implemented simd_gemm kernel for riscv vector extension (#20627)
Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
2026-04-16 11:14:26 +03:00
Yuannan
90fb96a7b3 devops : added spirv-headers to nix (#21965) 2026-04-16 11:12:52 +03:00
Reese Levine
82677a6ede ggml-webgpu: compute pass batching and removing profiling overhead (#21873)
* Update register tiling matmul to use f32 accumulation

* fix profiling code

* Fix register tiling matmul for chrome, i'm blaming dawn

* Update batch tuning value for iOS

* compile fix

* Fix use of new load function

* Move to a single query set for GPU profiling

* Move to batching compute passes when not profiling

* Refactor build_multi

* remove iOS throttling now that we're batching compute passes
2026-04-16 11:12:19 +03:00
Ludovic Henry
8612ed18b7 ci : Use ggml-org/ccache-action on RISC-V as well (#21632) 2026-04-16 11:11:25 +03:00
Katostrofik
b1be68e8ca [SYCL] Fix Q8_0 reorder: garbage on 2nd prompt + crash on full VRAM (#21638)
* [SYCL] Fix Q8_0 reorder: add missing dequantize path for GEMM

The Q8_0 reorder optimization (#21527) was missing a reorder-aware
dequantizer for the GEMM code path used during prompt processing.
After token generation reordered Q8_0 weights (via DMMV/MMVQ), the
next prompt processing pass would read them with the standard
dequantizer, producing garbage output.

Add dequantize_block_q8_0_reorder() and wire it into both
ggml_get_to_fp16_sycl() and ggml_get_to_fp32_sycl(), matching the
pattern already used by Q4_0, Q4_K, and Q6_K.

Fixes #21589

AI (Claude) was used to assist with root cause investigation and
writing the kernel code. All code was human-reviewed and tested
on real hardware.

* SYCL: fix reorder crash when device memory is full

The reorder optimization allocates a temporary buffer the full size of
the weight tensor on the device. When VRAM is nearly full (large models
on a single GPU), this allocation fails and the subsequent memcpy crashes
on a NULL pointer.

Fix: try device allocation first, fall back to host memory if device
memory is full. The reorder kernel still works correctly reading from
host memory over PCIe. This is slower for the one-time reorder (~21 t/s
vs ~38 t/s on Intel Arc Pro B70), but the optimization is preserved for
all subsequent inference. If both device and host allocation fail, skip
the reorder and fall back to the unoptimized kernel path.

Also fixes a bug where opt_for_reorder() marked tensors as reordered
even when the reorder was skipped due to allocation failure. This caused
DMMV/MMVQ kernels to read the original AoS data as if it were SoA,
producing garbage output or NaN results.

Tested on Intel Arc Pro B70 (32GB) with Q8_0, Q4_K_M models. Coding was
AI-assisted (Claude), reviewed and tested on hardware by a human.

Fixes #20478

* SYCL: add RAII temp buffer class + macro guard for host fallback

Replace sycl_ext_malloc_with_fallback/sycl_ext_free_fallback free
functions with sycl_reorder_temp_buffer RAII class. The host_fallback
bool is now a private member, and cleanup happens automatically at
scope exit.

Add GGML_SYCL_HOST_MEM_FALLBACK cmake option (default ON) to guard
the host memory fallback code path. Device access to host memory
requires Linux kernel 6.8+ (Ubuntu 26.04+); users on older kernels
can set -DGGML_SYCL_HOST_MEM_FALLBACK=OFF to disable it.

Addresses arthw's review on PR #21638.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* SYCL: document GGML_SYCL_HOST_MEM_FALLBACK build option in SYCL.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* SYCL: add reorder-aware DMMV dequantizers for Q4_K and Q6_K

Q4_K and Q6_K had reorder support for MMVQ and GEMM paths but not
DMMV. When the DMMV path encountered reordered data it would abort.

Add DMMV kernels that read from the SOA reorder layout for both
types. Same math as the non-reorder versions, different memory
access pattern.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 08:34:05 +03:00
Xuan-Son Nguyen
408225bb1a server: use random media marker (#21962)
* server: use random media marker

* nits

* remove legacy <__image__> token

* revert special char in random
2026-04-15 23:52:22 +02:00
Ruben Ortlam
b3d758750a vulkan: optimize im2col (#21713)
* vulkan: improve im2col memory write layout

* cap workgroups

* minimal device tuning

* use vendor_id instead of subgroup size
2026-04-15 19:04:51 +02:00
Pasha Khosravi
7e72b38bc1 cuda: Q1_0 initial backend (#21629)
* [cuda] initial Q1_0 backend

* remove unused code, fix AMD MMA guard

* attempt to support dp4a

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-15 18:38:38 +02:00
Reese Levine
20d3bc2cc8 ggml-webgpu: Fix dequantization helpers to not pass in pointers (#21872)
* Fix dequantization helpers to not pass in pointers

* Increase XIELU precision
2026-04-15 09:14:40 -07:00
Johannes Gäßler
a6206958d2 CUDA: require explicit opt-in for P2P access (#21910) 2026-04-15 16:01:46 +02:00
Johannes Gäßler
014dca49d6 CUDA: manage NCCL communicators in context (#21891)
* CUDA: manage NCCL communicators in context

* add check that all backends are CUDA

* remove unused vector, limit init to > 1 GPUs

* fix warnings

* fix cuda device, cache allreduce
2026-04-15 15:58:40 +02:00
Valeriy Dubov
adb541a6ad rpc : add native RDMA transport for RPC backend (RoCEv2) (#20590) 2026-04-15 16:44:02 +03:00
Xuan-Son Nguyen
80d8770804 docs: more extensive RoPE documentation [no ci] (#21953)
* more extensive ggml_rope documentation

* add more docs

* nits
2026-04-15 14:45:16 +02:00
Ruben Ortlam
8dc530b86d ci: disable test-backend-ops on Vulkan llvmpipe run and resture default timeout (#21901) 2026-04-15 10:55:21 +02:00
Piotr Wilkin (ilintar)
e1a9a6dcbe autoparser: support case of JSON_NATIVE with per-call markers (test case: Reka-Edge) (#21892) 2026-04-15 10:51:50 +02:00
Matt
e39eba26f3 read n_ctx back after making llama_context (#21939) 2026-04-15 15:24:57 +08:00
Yiwei Shao
5d14e5d19b hexagon: optimization for HMX mat_mul (#21554)
* hexagon: add async HMX worker

Introduce hmx-worker (dedicated thread for HMX compute) to overlap HMX
matmul with HVX dequant/DMA stages in the pipeline path, replacing the
previous synchronous HMX calls that blocked the main thread.

* hexagon: cost-based VTCM chunk search for out-stationary matmul

* hexagon: fix futex race in hmx_worker_drain
Store the boolean to local variable avoid atomic load twice

* hex-mm: hmx optimize scatter/transpose and use HMX intrinsics

* hex-vmem: drop vmem limit a touch under 3GB on v73

* hexagon: add fwd declaration of htp_context

* hex-hmx: replace hmx-worker with hmx-queue that mimics dma-queue interface

Simplifies the overall implemantion, reduces thread wakeup roundtrips.

* hex-mm: add debug log to hmx work func called from hmx-queue

* Update hmx-queue.h

Co-authored-by: Max Krasnyansky <max.krasnyansky@gmail.com>

---------

Co-authored-by: Kim-Chyan Gan <kgan@qti.qualcomm.com>
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
Co-authored-by: Max Krasnyansky <max.krasnyansky@gmail.com>
2026-04-14 14:09:03 -07:00
Xuan-Son Nguyen
fae3a28070 ggml : remove ggml-ext.h (#21869)
* ggml: correct placement of ggml-ext.h

* ggml : remove ggml-ext.h

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-04-14 17:32:58 +03:00
Georgi Gerganov
c0de6eda72 metal : fix FA support logic (#21898) 2026-04-14 17:32:29 +03:00
Xuan-Son Nguyen
707c0b7a6e mtmd: add mtmd_image_tokens_get_decoder_pos() API (#21851)
* mtmd: add mtmd_image_tokens_get_decoder_pos() API

* consistent naming

* fix build
2026-04-14 16:07:41 +02:00
Jeff Bolz
1f30ac0cea vulkan: Programmatically add RoundingModeRTE to all shaders when the device supports it (#21572)
* vulkan: Programmatically add RoundingModeRTE to all shaders when the device supports it

* use FetchContent to get SPIRV-Headers

* Fetch spirv-headers unconditionally

* remove fetchcontent, rely on installed headers

* fix ubuntu job

* Update docs/build.md
2026-04-14 15:17:45 +02:00
Georgi Gerganov
f4b5bf2f32 ci : re-enable mac workflows (#21894)
* ci : re-enable mac workflows

* vulkan : fix compile warning
2026-04-14 15:58:09 +03:00
Seyoung Jeong
aa0f1897b7 metal : add XIELU unary op (#20802) 2026-04-14 15:43:59 +03:00
Adrien Gallouët
be76dd0bb2 vendor : update BoringSSL to 0.20260413.0 (#21881)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-14 14:25:09 +03:00
Richard Davison
2e05f06ffb ggml : fix ARM NEON nvfp4 dot product on non-dotprod targets (#21559) 2026-04-14 14:23:45 +03:00
texasich
acc37a42ea cmake: fix CMP0194 warning on Windows with MSVC (#21630)
* cmake: fix CMP0194 warning on Windows with MSVC

Set CMP0194 policy to NEW before project() call in ggml/CMakeLists.txt to suppress the "MSVC is not an assembler for language ASM" warning introduced in CMake 4.1.

The ggml project enables ASM globally for Metal (macOS) and KleidiAI (ARM) backends. On Windows/MSVC, no assembler sources are used, but CMake 4.1+ warns because cl.exe is not a valid ASM compiler.

This follows the same pattern used in ggml-vulkan (CMP0114, CMP0147).

Closes ggml-org/llama.cpp#20311

* cmake: apply cisc's formatting suggestion

---------

Co-authored-by: texasich <texasich@users.noreply.github.com>
2026-04-14 13:47:56 +03:00
Reese Levine
5a23695d5a ggml-webgpu: Update register tiling matmul to use f32 accumulation (#21644)
* Update register tiling matmul to use f32 accumulation

* fix profiling code

* Fix register tiling matmul for chrome, i'm blaming dawn

* Update batch tuning value for iOS

* compile fix

* Fix use of new load function
2026-04-14 13:46:41 +03:00
Berk Idem
56666fa607 common: skip reasoning budget sampler when no budget is requested (#21870)
* common: skip reasoning budget sampler when no budget is requested

After I added thinking_start_tag / thinking_end_tag for gemma4 in #21697, the reasoning budget sampler gets unconditionally created even when no budget is configured (the default -1). The same applies to kimi_k2, lfm2, lfm2_5, and ministral_3 which also set these tags. The budget gets converted to INT_MAX, so the sampler never actually forces any tokens but still runs per-token checks (start tag matching in IDLE state, token-to-piece conversion + UTF-8 checks in COUNTING state).

More importantly, the mere existence of the sampler (non-null rbudget) disables backend sampling. Backend sampling lets the GPU select tokens directly, avoiding a full logits transfer from GPU to CPU every token. This could explain the 30% speed regression reported in #21784 (98 t/s to 70 t/s on Vulkan).

So I added a reasoning_budget_tokens >= 0 check to the sampler creation condition. When the budget is unlimited, the sampler is not created, backend sampling stays enabled, and no per-token overhead is added. When a budget is explicitly set (0, 128, 1024, etc.), the sampler is created and works as before.

* common: preserve rbudget when grammar is lazy

Following up on the review feedback on #21870: keep the reasoning budget sampler when grammar_lazy is true, so the thinking-block grammar suppression from #20970 still works when tools are in use. This way, we only skip the sampler when both no budget is set AND grammar is not lazy.
2026-04-14 12:43:06 +02:00
Jeff Bolz
6a6780a232 vulkan: Support GGML_TYPE_NVFP4 (#21455)
This adds nvfp4 support for get_rows, dequant, and mul_mat(_id). For
mul_mat, it does not add support for the dp4/q8_1 path, it's all via
fp16/fp32.
2026-04-14 11:34:23 +02:00
Xuan-Son Nguyen
e489a5ca0e server: support OAI /v1/audio/transcriptions API (#21863)
* server: support OAI /v1/audio/transcriptions API

* address autoreview comments

* correct default response_format value
2026-04-14 11:09:52 +02:00
228 changed files with 8568 additions and 4409 deletions

View File

@@ -18,6 +18,7 @@
vulkan-loader,
openssl,
shaderc,
spirv-headers,
useBlas ?
builtins.all (x: !x) [
useCuda
@@ -145,6 +146,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
ninja
pkg-config
git
spirv-headers
]
++ optionals useCuda [
cudaPackages.cuda_nvcc

View File

@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget xz-utils
# Install SSL and Vulkan SDK dependencies
RUN apt install -y libssl-dev curl \
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libvulkan-dev glslc spirv-headers
# Build it
WORKDIR /app

View File

@@ -47,22 +47,10 @@ jobs:
steps:
- name: Install dependencies
run: |
sudo apt-get update
# Install necessary packages
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 cmake build-essential wget git-lfs
# Set gcc-14 and g++-14 as the default compilers
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
if ! which rustc; then
# Install Rust stable version
sudo apt-get install -y rustup
rustup install stable
rustup default stable
fi
git lfs install
- name: GCC version check
@@ -74,12 +62,12 @@ jobs:
id: checkout
uses: actions/checkout@v6
# FIXME: Enable when ggml-org/ccache-action works on riscv64
# - name: ccache
# uses: ggml-org/ccache-action@v1.2.21
# with:
# key: ubuntu-riscv64-native-sanitizer-${{ matrix.sanytizer }}-${{ matrix.build_type }}
# save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: ccache
uses: ggml-org/ccache-action@afde29e5b5422e5da23cb1f639e8baecadeadfc3 # https://github.com/ggml-org/ccache-action/pull/1
with:
key: ubuntu-riscv64-native-sanitizer-${{ matrix.sanitizer }}-${{ matrix.build_type }}
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build

View File

@@ -141,61 +141,59 @@ jobs:
# amd-smi static
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
# TODO: sandbox Mac runners
# ggml-ci-mac-metal:
# runs-on: [self-hosted, macOS, ARM64]
#
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
#
# - name: Test
# id: ggml-ci
# run: |
# GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
#
# ggml-ci-mac-webgpu:
# runs-on: [self-hosted, macOS, ARM64]
#
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
#
# - name: Dawn Dependency
# id: dawn-depends
# run: |
# DAWN_VERSION="v2.0.0"
# DAWN_OWNER="reeselevine"
# DAWN_REPO="dawn"
# DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
# echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
# curl -L -o artifact.zip \
# "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
# mkdir dawn
# unzip artifact.zip
# tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
#
# - name: Test
# id: ggml-ci
# run: |
# GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
# bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
#
# ggml-ci-mac-vulkan:
# runs-on: [self-hosted, macOS, ARM64]
#
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
#
# - name: Test
# id: ggml-ci
# run: |
# vulkaninfo --summary
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-metal:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-webgpu:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Dawn Dependency
id: dawn-depends
run: |
DAWN_VERSION="v20260317.182325"
DAWN_OWNER="google"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-macos-latest-Release"
echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
curl -L -o artifact.tar.gz \
"https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
mkdir dawn
tar -xvf artifact.tar.gz -C dawn --strip-components=1
- name: Test
id: ggml-ci
run: |
GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-vulkan:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-linux-intel-vulkan:
runs-on: [self-hosted, Linux, Intel]

View File

@@ -93,4 +93,5 @@ jobs:
export GGML_VK_DISABLE_F16=1
export GGML_VK_DISABLE_COOPMAT=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4800
# test-backend-ops is too slow on llvmpipe, skip it
ctest -L main -E test-backend-ops --verbose --timeout 900

View File

@@ -318,7 +318,7 @@ jobs:
id: depends
run: |
sudo apt-get update
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev libssl-dev ninja-build
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev spirv-headers libssl-dev ninja-build
echo "CC=gcc-14" >> "$GITHUB_ENV"
echo "CXX=g++-14" >> "$GITHUB_ENV"
@@ -1001,22 +1001,14 @@ jobs:
steps:
- name: Install dependencies
run: |
sudo apt-get update
# Install necessary packages
sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 cmake build-essential libssl-dev wget git-lfs
sudo apt-get update
sudo apt-get install -y libssl-dev
# Set gcc-14 and g++-14 as the default compilers
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-14 100
if ! which rustc; then
# Install Rust stable version
sudo apt-get install -y rustup
rustup install stable
rustup default stable
fi
git lfs install
- name: Check environment
@@ -1032,13 +1024,12 @@ jobs:
id: checkout
uses: actions/checkout@v6
# FIXME: Enable when ggml-org/ccache-action works on riscv64
# - name: ccache
# uses: ggml-org/ccache-action@v1.2.21
# with:
# key: ubuntu-cpu-riscv64-native
# evict-old-files: 1d
# save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: ccache
uses: ggml-org/ccache-action@afde29e5b5422e5da23cb1f639e8baecadeadfc3 # https://github.com/ggml-org/ccache-action/pull/1
with:
key: ubuntu-cpu-riscv64-native
evict-old-files: 1d
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build

View File

@@ -202,7 +202,7 @@ jobs:
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libssl-dev
else
sudo apt-get update -y
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev libssl-dev ninja-build
sudo apt-get install -y gcc-14 g++-14 build-essential glslc libvulkan-dev spirv-headers libssl-dev ninja-build
echo "CC=gcc-14" >> "$GITHUB_ENV"
echo "CXX=g++-14" >> "$GITHUB_ENV"
fi

View File

@@ -84,41 +84,42 @@ jobs:
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"
server-cuda:
runs-on: [self-hosted, llama-server, Linux, NVIDIA]
name: server-cuda (${{ matrix.wf_name }})
strategy:
matrix:
build_type: [Release]
wf_name: ["GPUx1"]
include:
- build_type: Release
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
wf_name: "GPUx1, backend-sampling"
fail-fast: false
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Build
id: cmake_build
run: |
cmake -B build -DGGML_SCHED_NO_REALLOC=ON
cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
- name: Tests
id: server_integration_tests
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
run: |
cd tools/server/tests
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"
# TODO: provision CUDA runner
# server-cuda:
# runs-on: [self-hosted, llama-server, Linux, NVIDIA]
#
# name: server-cuda (${{ matrix.wf_name }})
# strategy:
# matrix:
# build_type: [Release]
# wf_name: ["GPUx1"]
# include:
# - build_type: Release
# extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
# wf_name: "GPUx1, backend-sampling"
# fail-fast: false
#
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v6
# with:
# fetch-depth: 0
# ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
#
# - name: Build
# id: cmake_build
# run: |
# cmake -B build -DGGML_SCHED_NO_REALLOC=ON
# cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
#
# - name: Tests
# id: server_integration_tests
# if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
# run: |
# cd tools/server/tests
# python3 -m venv venv
# source venv/bin/activate
# pip install -r requirements.txt
# export ${{ matrix.extra_args }}
# pytest -v -x -m "not slow"

View File

@@ -1,5 +1,21 @@
# collaborators can optionally add themselves here to indicate their availability for reviewing related PRs
# multiplie collaborators per item can be specified
# multiple collaborators per item can be specified
#
# ggml-org/ci : CISC, danbev, ggerganov, netrunnereve, ngxson, taronaeo
# ggml-org/ggml-cann : hipudding
# ggml-org/ggml-cuda : JohannesGaessler, am17an, IMbackK, ORippler
# ggml-org/ggml-hexagon : lhez, max-krasnyansky
# ggml-org/ggml-metal : ggerganov
# ggml-org/ggml-opencl : lhez, max-krasnyansky
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
# ggml-org/llama-server : ggerganov, ngxson, allozaur, angt, ServeurpersoCom
# ggml-org/llama-webui : allozaur
/.devops/*.Dockerfile @ngxson
/.github/actions/ @ggml-org/ci

View File

@@ -198,10 +198,19 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
args_field = format.function_field + "." + args_field;
}
auto tools_parser = p.standard_json_tools(
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
auto tools_parser = p.eps();
if (format.section_start.empty() && !format.per_call_start.empty()) {
auto single_tool_parser = p.standard_json_tools(
format.per_call_start, format.per_call_end, inputs.tools, inputs.parallel_tool_calls,
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
tools_parser = p.trigger_rule("tool-calls", p.one_or_more(single_tool_parser + p.space()));
} else {
tools_parser = p.standard_json_tools(
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
}
// Handle content wrappers if present
if (ctx.content && ctx.content->is_always_wrapped()) {

View File

@@ -308,19 +308,23 @@ struct analyze_tools : analyze_base {
private:
// Extract tool calling 'haystack' for further analysis and delegate further analysis based on format
void analyze_tool_calls(const analyze_reasoning & reasoning);
void analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls);
// Analyze format based on position of function and argument name in needle
void analyze_tool_call_format(const std::string & haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle,
const analyze_reasoning & reasoning);
const analyze_reasoning & reasoning,
bool supports_parallel_tool_calls);
// Analyze specifics of JSON native format (entire tool call is a JSON object)
void analyze_tool_call_format_json_native(const std::string & clean_haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle);
// Check if parallel calls in JSON native format array wrapped or tag wrapped
void analyze_json_native_parallel_calls();
// Analyze specifics of non-JSON native format (tags for function name or for function name and arguments)
void analyze_tool_call_format_non_json(const std::string & clean_haystack,
const std::string & fun_name_needle);

View File

@@ -558,7 +558,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
: analyze_base(tmpl) {
LOG_DBG(ANSI_ORANGE "Phase 3: Tool call analysis\n" ANSI_RESET);
analyze_tool_calls(reasoning);
analyze_tool_calls(reasoning, caps.supports_parallel_tool_calls);
if (format.mode != tool_format::NONE && format.mode != tool_format::JSON_NATIVE) {
if (caps.supports_parallel_tool_calls) {
@@ -577,7 +577,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
}
}
void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) {
void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls) {
json assistant_no_tools = json{
{ "role", "assistant" },
{ "content", ASSISTANT_MSG }
@@ -611,13 +611,14 @@ void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) {
return;
}
analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning);
analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning, supports_parallel_tool_calls);
}
void analyze_tools::analyze_tool_call_format(const std::string & haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle,
const analyze_reasoning & reasoning) {
const analyze_reasoning & reasoning,
bool supports_parallel_tool_calls) {
if (fun_name_needle.empty() || arg_name_needle.empty() || haystack.empty()) {
return;
}
@@ -660,6 +661,9 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
if (format.mode == tool_format::JSON_NATIVE) {
analyze_tool_call_format_json_native(clean_haystack, fun_name_needle, arg_name_needle);
if (supports_parallel_tool_calls) {
analyze_json_native_parallel_calls();
}
} else {
analyze_tool_call_format_non_json(clean_haystack, fun_name_needle);
}
@@ -668,6 +672,42 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
format.per_call_end = trim_whitespace(format.per_call_end);
}
void analyze_tools::analyze_json_native_parallel_calls() {
json assistant_one_tool = json{
{ "role", "assistant" },
{ "content", "" },
{ "tool_calls", json::array({ first_tool_call }) }
};
json assistant_two_tools = json{
{ "role", "assistant" },
{ "content", "" },
{ "tool_calls", json::array({ first_tool_call, second_tool_call }) }
};
template_params params;
params.messages = json::array({ user_msg, assistant_one_tool });
params.tools = tools;
params.add_generation_prompt = false;
params.enable_thinking = true;
auto comparison = compare_variants(
*tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_tools }); });
if (!comparison) {
LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__);
return;
}
std::string & second_call = comparison->diff.right;
if (!format.section_start.empty() && second_call.find(format.section_start) != std::string::npos) {
format.per_call_start = format.section_start;
format.per_call_end = format.section_end;
format.section_start.clear();
format.section_end.clear();
}
}
void analyze_tools::analyze_tool_call_format_json_native(const std::string & clean_haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle) {

View File

@@ -676,7 +676,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
atomic(literal("\"") + tool_name(literal(name)) + literal("\""));
auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() +
tool_args(schema(json(), "tool-" + name + "-schema", params));
@@ -744,7 +744,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
atomic(literal("\"") + tool_name(literal(name)) + literal("\""));
auto tool_args_ = args_key_parser + space() + literal(":") + space() +
tool_args(schema(json(), "tool-" + name + "-schema", params));

View File

@@ -287,8 +287,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// reasoning budget sampler
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
// reasoning budget sampler (skip when budget is unlimited unless a lazy grammar is active, which needs rbudget for thinking-block suppression)
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0)) {
rbudget = common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,

View File

@@ -10893,7 +10893,64 @@ class NemotronHModel(GraniteHybridModel):
self.gguf_writer.add_moe_latent_size(latent_size)
def set_vocab(self):
super().set_vocab()
# The NemotronH config uses pattern characters (e.g. '-') that may not
# be supported by the installed transformers version. AutoTokenizer
# internally calls AutoConfig which triggers this parsing failure.
# Using trust_remote_code=True to load the model's own config class.
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
# Pad vocab size (from Mamba2Model/GraniteHybridModel)
self.hparams["pad_vocab_size_multiple"] = 8 # Setting this here since GraniteHybridModel.set_vocab() isn't being invoked now.
# From Mamba2Model.set_vocab():
vocab_size = self.hparams["vocab_size"]
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size
assert max(tokenizer.vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
# From TextModel.set_vocab_gpt2():
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
# The tokenizer _does_ add a BOS token (via post_processor type
# TemplateProcessing) but does not set add_bos_token to true in the

View File

@@ -689,6 +689,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) |
| GGML_SYCL_GRAPH | OFF *(default)* \|ON *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
| GGML_SYCL_HOST_MEM_FALLBACK | ON *(default)* \|OFF *(Optional)* | Allow host memory fallback when device memory is full during quantized weight reorder. Enables inference to continue at reduced speed (reading over PCIe) instead of failing. Requires Linux kernel 6.8+. |
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |

View File

@@ -281,6 +281,12 @@ Use `GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F` environment variable to force use FP16
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
### Peer Access
The environment variable `GGML_CUDA_P2P` can be set to enable peer-to-peer access between multiple GPUs, allowing them to transfer data directly rather than to go through system memory.
Requires driver support (usually restricted to workstation/datacenter GPUs).
May cause crashes or corrupted outputs for some motherboards and BIOS settings (e.g. IOMMU).
### Performance Tuning
The following compilation options are also available to tweak performance:
@@ -456,7 +462,8 @@ pacman -S git \
mingw-w64-ucrt-x86_64-gcc \
mingw-w64-ucrt-x86_64-cmake \
mingw-w64-ucrt-x86_64-vulkan-devel \
mingw-w64-ucrt-x86_64-shaderc
mingw-w64-ucrt-x86_64-shaderc \
mingw-w64-ucrt-x86_64-spirv-headers
```
Switch into the `llama.cpp` directory and build using CMake.
@@ -490,9 +497,11 @@ First, follow the official LunarG instructions for the installation and setup of
On Debian / Ubuntu, you can install the required dependencies using:
```sh
sudo apt-get install libvulkan-dev glslc
sudo apt-get install libvulkan-dev glslc spirv-headers
```
SPIRV-Headers (`spirv/unified1/spirv.hpp`) are required for the Vulkan backend and are **not** always pulled in by the Vulkan loader dev package alone. Other distros use names such as `spirv-headers` (Ubuntu / Debian / Arch), or `spirv-headers-devel` (Fedora / openSUSE). On Windows, the LunarG Vulkan SDKs `Include` directory already contains these headers.
#### Common steps
Second, after verifying that you have followed all of the SDK installation/setup steps, use this command to make sure before proceeding:

View File

@@ -130,6 +130,23 @@ Note:
- Adding a model-specific API or CLI is an anti-pattern in `libmtmd`. The goal of `libmtmd` is to provide an easy-to-use, model-agnostic library for multimodal pipeline.
- In most cases, `llama-mtmd-cli` should not be modified. If a model requires a specific prompt, either let the user provide it or bake it into the Jinja chat template.
## Tips and tricks
### Working with ggml_rope_ext
PyTorch implementations usually prefer explicitly calculating `freq_cis`/`sin`/`cos` components. However, in llama.cpp, most RoPE operations can be handled via `ggml_rope_ext`, which does not require a sin/cos matrix. This saves memory while allowing the GGML RoPE kernel to be fused with other ops.
However, since `ggml_rope_ext` only provides a subset of the RoPE implementations that models use, converting models from PyTorch to llama.cpp may require some creative adaptations.
For more information about `ggml_rope_ext`, please refer to the in-code documentation in `ggml.h`.
Examples:
- `libmtmd` implements 2D RoPE with `GGML_ROPE_TYPE_NORMAL` ordering by splitting the input tensor in half, applying `ggml_rope_ext` separately to each half, then joining them back together using `ggml_concat`.
- The [Kimi-K2.5](https://github.com/ggml-org/llama.cpp/pull/19170) vision encoder uses vision RoPE with interleaved frequencies. The weights must be permuted during conversion in order to reuse the `build_rope_2d()` function.
- [Gemma 4](https://github.com/ggml-org/llama.cpp/pull/21309) uses "proportional" RoPE. We employ a trick where `rope_freqs` is set to a very large value in the last dimensions to prevent those dimensions from being rotated. See the `Gemma4Model` class in `convert_hf_to_gguf.py`.
- Some models require scaling the input position. For example, `[0, 1, 2, ...]` becomes `[0, 0.5, 1, ...]`. In this case, you can provide the scaling via `freq_scale = 0.5f`.
- Some models use learned RoPE frequencies instead of relying on `powf(freq_base, -2.0 * i / n_dims)`. In this case, you can provide the learned frequencies via the `rope_freqs` tensor (corresponding to the `c` argument in `ggml_rope_ext`), then set `freq_base = 1.0f`. An important note is that `rope_freqs` in GGML is the **inverse** (`theta = pos[i] / rope_freqs`), so you may need to invert `rope_freqs` during conversion.
## GGUF specification
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md

View File

@@ -22,13 +22,13 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
@@ -46,7 +46,7 @@ Legend:
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -84,10 +84,10 @@ Legend:
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -116,6 +116,6 @@ Legend:
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

View File

@@ -602,8 +602,8 @@ int main(int argc, char ** argv) {
int n_input = input_tokens.size();
if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
if (static_cast<uint32_t>(n_input) >= llama_n_ctx(ctx)) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, llama_n_ctx(ctx));
llama_free(ctx);
llama_model_free(model);
return 1;

View File

@@ -1,4 +1,11 @@
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html
# MSVC is not a valid assembler for the ASM language.
# Set to NEW to avoid a warning on CMake 4.1+ with MSVC.
if (POLICY CMP0194)
cmake_policy(SET CMP0194 NEW)
endif()
project("ggml" C CXX ASM)
### GGML Version
@@ -247,6 +254,7 @@ option(GGML_RPC "ggml: use RPC"
option(GGML_SYCL "ggml: use SYCL" OFF)
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON)
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")

View File

@@ -202,8 +202,11 @@ extern "C" {
// Common functions that may be obtained using ggml_backend_reg_get_proc_address
// AllReduce operation for tensor parallelism (meta backend)
typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
// Context management and operations for faster communication between backends, used for tensor parallelism (meta backend)
typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends);
typedef void (*ggml_backend_comm_free_t)(void * comm_ctx);
typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors);
// Split buffer type for tensor parallelism (old)
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
// Set the number of threads for the backend
@@ -348,6 +351,53 @@ extern "C" {
// Set a callback to be called for each resulting node during graph compute
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
//
// Meta backend
//
#define GGML_BACKEND_META_MAX_DEVICES 16
enum ggml_backend_meta_split_axis {
// tensor split by tensor dimensions:
GGML_BACKEND_SPLIT_AXIS_0 = 0,
GGML_BACKEND_SPLIT_AXIS_1 = 1,
GGML_BACKEND_SPLIT_AXIS_2 = 2,
GGML_BACKEND_SPLIT_AXIS_3 = 3,
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
// for internal bookkeeping only:
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
};
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
struct ggml_backend_meta_split_state {
enum ggml_backend_meta_split_axis axis;
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
// - each device has a slice of the tensor along the split axis
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
// - some tensors have an inhomogenenous data layout along the split axis,
// those tensors are divided into segments which are each individually split across devices
// - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
uint32_t n_segments;
};
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
// express this as a backend registry functionality instead
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
//
// Utils
//

View File

@@ -6,9 +6,9 @@
extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 6
#define RPC_PROTO_PATCH_VERSION 1
#define RPC_PROTO_MAJOR_VERSION 4
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");

View File

@@ -1773,8 +1773,32 @@ extern "C" {
int n_dims,
int mode);
// custom RoPE
// RoPE operations with extended options
// a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token]
// b is an int32 vector with size n_token
// c is freq factors (e.g. phi3-128k), (optional)
// mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi
//
// pseudo-code for computing theta:
// for i in [0, n_dims/2):
// theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims);
// theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c
// theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here
//
// other params are used by YaRN RoPE scaling, these default values will disable YaRN:
// freq_scale = 1.0f
// ext_factor = 0.0f
// attn_factor = 1.0f
// beta_fast = 0.0f
// beta_slow = 0.0f
//
// example:
// (marking: c = cos, s = sin, 0 = unrotated)
// given a single head with size = 8 --> [00000000]
// GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000]
// GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs]
// GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000]
// GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss]
GGML_API struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1790,6 +1814,36 @@ extern "C" {
float beta_fast,
float beta_slow);
// multi-dimensional RoPE, for Qwen-VL and similar vision models
// mode can be either VISION, MROPE, IMROPE, cannot be combined with NORMAL or NEOX
// sections specify how many dimensions to rotate in each section:
// section length is equivalent to number of cos/sin pairs, NOT the number of dims
// (i.e. sum of 4 sections are expected to be n_dims/2)
// last sections can be 0, means ignored
// all other options are identical to ggml_rope_ext
//
// important note:
// - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION
// if you need normal ordering, there are 2 methods:
// (1) split the tensor manually using ggml_view
// (2) permute the weight upon conversion
// - for VISION, n_dims must be head_size/2
//
// example M-RoPE:
// given sections = [t=4, y=2, x=2, 0]
// given a single head with size = 18 --> [000000000000000000]
// GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering)
// GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering)
// note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section
// in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section
//
// example vision RoPE:
// given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored)
// given a single head with size = 8 --> [00000000]
// GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx]
// other values of n_dims are untested and is undefined behavior
// note: unlike MROPE, the theta for each dim is computed differently for each section
// in other words, idx used for theta: [0123] for y section, then [0123] for x section
GGML_API struct ggml_tensor * ggml_rope_multi(
struct ggml_context * ctx,
struct ggml_tensor * a,

View File

@@ -2,6 +2,7 @@
#include "ggml-backend-impl.h"
#include "ggml.h"
#include "ggml-impl.h"
#include <assert.h>
#include <limits.h>
#include <stdarg.h>

View File

@@ -5,9 +5,6 @@
#include "ggml-alloc.h"
#include "ggml-cpp.h"
// TODO: tmp
#include "ggml-ext.h"
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -1422,22 +1419,48 @@ struct ggml_backend_meta_context {
size_t max_tmp_size = 0;
size_t max_subgraphs = 0;
void * comm_ctx = nullptr;
ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
name = "Meta(";
std::vector<ggml_backend_t> simple_backends;
backend_configs.reserve(n_devs);
simple_backends.reserve(n_devs);
for (size_t i = 0; i < n_devs; i++) {
ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
if (i > 0) {
name += ",";
}
name += ggml_backend_dev_name(simple_dev);
backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params));
simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
backend_configs.emplace_back(simple_backends.back());
}
name += ")";
if (n_devs > 1) {
ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
if (comm_init != nullptr) {
comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
}
}
if (comm_ctx != nullptr) {
comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
GGML_ASSERT(comm_allreduce != nullptr);
}
}
~ggml_backend_meta_context() {
if (comm_ctx != nullptr) {
ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
GGML_ASSERT(comm_free != nullptr);
comm_free(comm_ctx);
}
for (auto & bc : backend_configs) {
ggml_backend_free(bc.backend);
}
@@ -1848,20 +1871,15 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
if (n_backends > 1 && i < n_subgraphs - 1) {
bool backend_allreduce_success = false;
ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor");
if (allreduce_tensor) {
std::vector<ggml_backend_t> backends;
backends.reserve(n_backends);
if (backend_ctx->comm_ctx) {
std::vector<ggml_tensor *> nodes;
nodes.reserve(n_backends);
for (size_t j = 0; j < n_backends; j++) {
auto & bcj = backend_ctx->backend_configs[j];
backends.push_back(bcj.backend);
ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
}
backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends);
backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
}
if (!backend_allreduce_success) {

View File

@@ -1030,6 +1030,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
GGML_ABORT("%s: failed to initialize context\n", __func__);
}
graph->uid = ggml_graph_next_uid();
// pass 1: assign backends to ops with pre-allocated inputs
for (int i = 0; i < graph->n_leafs; i++) {
struct ggml_tensor * leaf = graph->leafs[i];
@@ -1477,6 +1479,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
assert(graph_copy->size > graph_copy->n_leafs);
graph_copy->leafs[graph_copy->n_leafs++] = leaf;
}
// set ids for all splits
for (int i = 0; i < sched->n_splits; ++i) {
sched->splits[i].graph.uid = ggml_graph_next_uid();
}
}
static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {

View File

@@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b));
const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
#if defined(__ARM_FEATURE_DOTPROD)
const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
@@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
const int32x4_t p0 = vaddq_s32(
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
const int32x4_t p1 = vaddq_s32(
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
const int32x4_t sums = vpaddq_s32(p0, p1);
const int32x4_t sumi = vpaddq_s32(p0, p1);
#else
const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0);
const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0);
const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0);
const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0);
const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1);
const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1);
const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1);
const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1);
const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs);
const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8);
const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16);
const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24);
const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs);
const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8);
const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16);
const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24);
const int32x4_t sumi = (int32x4_t){
vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)),
vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)),
vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)),
vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)),
};
#endif
// Decode 4 UE4M3 scales to f32 and multiply with q8 scales
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float32x4_t nvsc = {
@@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
};
const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales);
}
sumf = vaddvq_f32(acc);
#else

File diff suppressed because it is too large Load Diff

View File

@@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
#if !defined(__ARM_FEATURE_DOTPROD)
// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter.
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
@@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#endif // !defined(__ARM_FEATURE_DOTPROD)
static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo,
const int8x8_t q4_hi, const int8x8_t q8_hi) {
const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo);
const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi);
const int32x4_t sum_lo = vpaddlq_s16(p_lo);
const int32x4_t sum_hi = vpaddlq_s16(p_hi);
return vaddq_s32(sum_lo, sum_hi);
}
#endif // defined(__ARM_NEON)
#ifdef __wasm_simd128__

View File

@@ -109,6 +109,96 @@ static void simd_gemm(
C += N;
}
}
#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic)
// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7
// Microkernel: C[RM x vl] += A[RM x K] * B[K x N]
template <int RM>
static inline void rvv_simd_gemm_ukernel(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int K, int N, size_t vl)
{
static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4");
vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl);
vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6;
if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl);
if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl);
if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl);
if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl);
if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl);
if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl);
for (int kk = 0; kk < K; kk++) {
vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl);
acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl);
if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl);
if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl);
if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl);
if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl);
if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl);
if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl);
}
__riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl);
if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl);
if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl);
if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl);
if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl);
if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl);
if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl);
}
template <int RM>
static inline void rvv_simd_gemm_dispatch_tail(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int K, int N, int KN, int remaining_rows)
{
if constexpr (RM > 0) {
if (remaining_rows == RM) {
int64_t jj = 0;
for (; jj + KN <= N; jj += KN) {
rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, KN);
}
if (jj < N) {
rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, N - jj);
}
} else {
rvv_simd_gemm_dispatch_tail<RM - 1>(C, A, B, K, N, KN, remaining_rows);
}
}
}
static constexpr int GEMM_RM = 7;
// C[M x N] += A[M x K] * B[K x N]
static void simd_gemm(
float * GGML_RESTRICT C,
const float * GGML_RESTRICT A,
const float * GGML_RESTRICT B,
int M, int K, int N)
{
const int KN = (int)__riscv_vlenb();
int64_t ii = 0;
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
int64_t jj = 0;
for (; jj + KN <= N; jj += KN) {
rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, KN);
}
if (jj < N) {
rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, N - jj);
}
A += GEMM_RM * K;
C += GEMM_RM * N;
}
int remaining_rows = M - ii;
rvv_simd_gemm_dispatch_tail<GEMM_RM - 1>(C, A, B, K, N, KN, remaining_rows);
}
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop

View File

@@ -924,6 +924,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_F16> {
static constexpr int qr = 1;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
static constexpr int qk = QK1_0;
static constexpr int qr = QR1_0;
static constexpr int qi = QI1_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
static constexpr int qk = QK4_0;
@@ -1092,10 +1099,6 @@ struct ggml_cuda_device_info {
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
#ifdef GGML_USE_NCCL
ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
#endif // GGML_USE_NCCL
};
const ggml_cuda_device_info & ggml_cuda_info();
@@ -1183,6 +1186,7 @@ struct ggml_cuda_graph {
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool warmup_complete = false;
uint64_t uid = 0;
struct node_properties {
ggml_tensor node;
void * node_src_data_ptrs[GGML_MAX_SRC];

View File

@@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
@@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
@@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
@@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float, nv_bfloat16>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
@@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F16:
return convert_unary_cuda<half, float>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:

View File

@@ -1,5 +1,27 @@
#include "common.cuh"
static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q1_0 * x = (const block_q1_0 *) vx;
const float d = x[ib].d;
const int bit_index_0 = iqs;
const int bit_index_1 = iqs + 1;
const int byte_index_0 = bit_index_0 / 8;
const int bit_offset_0 = bit_index_0 % 8;
const int byte_index_1 = bit_index_1 / 8;
const int bit_offset_1 = bit_index_1 % 8;
// Extract bits: 1 = +d, 0 = -d (branchless)
const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1;
const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1;
v.x = (2*bit_0 - 1) * d;
v.y = (2*bit_1 - 1) * d;
}
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;

View File

@@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q1_0:
get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_0:
get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

View File

@@ -324,28 +324,22 @@ static ggml_cuda_device_info ggml_cuda_init() {
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
for (int id = 0; id < info.device_count; ++id) {
ggml_cuda_set_device(id);
for (int id_other = 0; id_other < info.device_count; ++id_other) {
if (id == id_other) {
continue;
}
int can_access_peer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
if (can_access_peer) {
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
if (getenv("GGML_CUDA_P2P") != nullptr) {
for (int id = 0; id < info.device_count; ++id) {
ggml_cuda_set_device(id);
for (int id_other = 0; id_other < info.device_count; ++id_other) {
if (id == id_other) {
continue;
}
int can_access_peer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
if (can_access_peer) {
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
}
}
}
}
#ifdef GGML_USE_NCCL
int dev_ids[GGML_CUDA_MAX_DEVICES];
for (int id = 0; id < info.device_count; ++id) {
dev_ids[id] = id;
}
NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
#endif // GGML_USE_NCCL
return info;
}
@@ -1125,66 +1119,51 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
/* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
};
bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) {
#ifdef GGML_USE_NCCL
const int64_t ne = ggml_nelements(tensors[0]);
// FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
// This then causes a crash in this function
if (ne == 0) {
return true;
}
for (size_t i = 0; i < n_backends; ++i) {
GGML_ASSERT(tensors[i] != nullptr);
GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
}
struct ggml_backend_cuda_comm_context {
std::vector<ggml_backend_t> backends;
std::vector<ncclComm_t> comms;
const ggml_cuda_device_info info = ggml_cuda_info();
// For small tensors, simply reduce them as FP32.
// The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
~ggml_backend_cuda_comm_context() {
for (ncclComm_t comm : comms) {
NCCL_CHECK(ncclCommDestroy(comm));
}
NCCL_CHECK(ncclGroupEnd());
return true;
}
};
#endif // GGML_USE_NCCL
// For large tensors it's faster to compress them to BF16 for the reduction:
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
static void ggml_backend_cuda_comm_free(void * comm_ctx_v) {
#ifdef GGML_USE_NCCL
if (comm_ctx_v == nullptr) {
return;
}
ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
delete comm_ctx;
#else
GGML_UNUSED(comm_ctx_v);
#endif // GGML_USE_NCCL
}
ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
for (size_t i = 0; i < n_backends; ++i) {
static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) {
#ifdef GGML_USE_NCCL
for (size_t i = 0; i < n_backends; i++) {
if (!ggml_backend_is_cuda(backends[i])) {
return nullptr;
}
}
ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context;
std::vector<int> dev_ids;
ret->backends.reserve(n_backends);
dev_ids.reserve(n_backends);
for (size_t i = 0; i < n_backends; i++) {
ret->backends.push_back(backends[i]);
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
tmp[i].pool = &cuda_ctx->pool();
tmp[i].alloc(ne);
ggml_cuda_set_device(i);
to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
dev_ids.push_back(cuda_ctx->device);
}
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
ggml_cuda_set_device(i);
to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
return true;
ret->comms.resize(n_backends);
NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data()));
return ret;
#else
// If NCCL is installed it is used by default for optimal performance.
// However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
@@ -1197,7 +1176,76 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
warning_printed = true;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
GGML_UNUSED_VARS(backends, tensors, n_backends);
GGML_UNUSED_VARS(backends, n_backends);
return nullptr;
#endif // GGML_USE_NCCL
}
static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) {
#ifdef GGML_USE_NCCL
const int64_t ne = ggml_nelements(tensors[0]);
// FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
// This then causes a crash in this function
if (ne == 0) {
return true;
}
GGML_ASSERT(comm_ctx_v != nullptr);
ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
const size_t n_backends = comm_ctx->backends.size();
for (size_t i = 0; i < n_backends; ++i) {
GGML_ASSERT(tensors[i] != nullptr);
GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
}
// For small tensors, simply reduce them as FP32.
// The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
return true;
}
// For large tensors it's faster to compress them to BF16 for the reduction:
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
tmp[i].pool = &cuda_ctx->pool();
tmp[i].alloc(ne);
ggml_cuda_set_device(cuda_ctx->device);
to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
ggml_cuda_set_device(cuda_ctx->device);
to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
return true;
#else
GGML_UNUSED_VARS(comm_ctx_v, tensors);
return false;
#endif // GGML_USE_NCCL
}
@@ -3060,6 +3108,15 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (cgraph->uid != 0 &&
cgraph->uid == graph->uid) {
GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid);
GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes);
return false;
}
graph->uid = cgraph->uid;
// Check if the graph size has changed
if ((int)graph->node_props.size() != cgraph->n_nodes) {
res = true;
@@ -4783,6 +4840,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -4820,6 +4878,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
case GGML_TYPE_I32:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -5220,8 +5279,14 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
GGML_UNUSED(reg);
if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) {
return (void *)ggml_backend_cuda_allreduce_tensor;
if (strcmp(name, "ggml_backend_comm_init") == 0) {
return (void *)ggml_backend_cuda_comm_init;
}
if (strcmp(name, "ggml_backend_comm_free") == 0) {
return (void *)ggml_backend_cuda_comm_free;
}
if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
return (void *)ggml_backend_cuda_comm_allreduce_tensor;
}
if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
return (void *)ggml_backend_cuda_split_buffer_type;

View File

@@ -5,6 +5,9 @@
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q1_0:
mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
@@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
bool mmq_supported;
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:

View File

@@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
case GGML_TYPE_Q1_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() {
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
// ------------------------------------------------------------
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
constexpr int threads_per_row = blocks_per_iter * QI1_0;
constexpr int nrows = warp_size / threads_per_row;
constexpr int scale_entries_per_block = QK1_0 / QK8_1;
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
const int txi = threadIdx.x % threads_per_row;
const int kbx = txi / QI1_0;
const int kqsx = txi % QI1_0;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
}
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
const int qs_offset = 4*kqsx;
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
int unpacked_bytes[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int shift = j * 4;
const int bits4 = (qs0 >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
const int ksx = threadIdx.x % scale_entries_per_row;
const int scale_block = ksx / scale_entries_per_block;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
if (need_check) {
i = min(i, i_max);
}
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
#else
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3290,6 +3375,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
struct mmq_type_traits;
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;

View File

@@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
@@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
@@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type(
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const int ids_stride, cudaStream_t stream) {
switch (type_x) {
case GGML_TYPE_Q1_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,

View File

@@ -32,6 +32,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
TYPES_MMQ = [
"GGML_TYPE_Q1_0",
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_Q1_0);

View File

@@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
@@ -669,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
}
static __device__ __forceinline__ float vec_dot_q1_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx;
// Q1_0: 128 elements with ONE scale
// Q8_1: 32 elements per block with individual scales
// iqs selects which of the 4 chunks of 32 elements to process (0-3)
const float d1 = bq1_0->d;
// Process only the chunk specified by iqs
const block_q8_1 * bq8_1_chunk = bq8_1 + iqs;
// Load 32 bits (4 bytes) for this chunk from Q1_0
const int offset = iqs * 4;
const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) |
(bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24);
// Unpack 32 bits into 32 signed values (-1 or +1)
int vi_bytes[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int shift = j * 4;
const int bits4 = (v >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}
// Compute dot product for this 32-element chunk
int sumi = 0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int u = get_int_b4(bq8_1_chunk->qs, j);
sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi);
}
// Apply Q1_0's single scale and this chunk's Q8_1 scale
const float d8 = __low2float(bq8_1_chunk->ds);
return d1 * d8 * sumi;
}
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

View File

@@ -1,56 +0,0 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
// This is a "staging" header for new ggml API
// It is not publicly available and it should not be used by 3rd party projects
//
// When the API matures enough, it will be moved to the official public API
//
// Meta backend
//
#define GGML_BACKEND_META_MAX_DEVICES 16
enum ggml_backend_meta_split_axis {
// tensor split by tensor dimensions:
GGML_BACKEND_SPLIT_AXIS_0 = 0,
GGML_BACKEND_SPLIT_AXIS_1 = 1,
GGML_BACKEND_SPLIT_AXIS_2 = 2,
GGML_BACKEND_SPLIT_AXIS_3 = 3,
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
// for internal bookkeeping only:
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
};
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
struct ggml_backend_meta_split_state {
enum ggml_backend_meta_split_axis axis;
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
// - each device has a slice of the tensor along the split axis
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
// - some tensors have an inhomogenenous data layout along the split axis,
// those tensors are divided into segments which are each individually split across devices
// - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
uint32_t n_segments;
};
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
// express this as a backend registry functionality instead
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);

View File

@@ -47,6 +47,7 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
hmx-matmul-ops.c
)

View File

@@ -31,6 +31,14 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt;
}
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
@@ -73,8 +81,7 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride,
#define HEX_L2_LINE_SIZE 64
#define HEX_L2_FLUSH_SIZE (128 * 1024)
static inline void hex_l2flush(void * addr, size_t size)
{
static inline void hex_l2flush(void * addr, size_t size) {
if (size > HEX_L2_FLUSH_SIZE) {
qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
} else {
@@ -89,4 +96,8 @@ static inline void hex_l2flush(void * addr, size_t size)
}
}
static inline void hex_pause() {
asm volatile(" pause(#255)\n");
}
#endif /* HEX_UTILS_H */

View File

@@ -16,14 +16,16 @@
#include "ggml-common.h"
#include "hex-dma.h"
#include "worker-pool.h"
#include "hvx-utils.h"
#include "hvx-dump.h"
#include "worker-pool.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "hmx-utils.h"
#include "hmx-ops.h"
#include "hmx-utils.h"
#include "hmx-queue.h"
#include "hmx-profile.h"
static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
@@ -47,7 +49,8 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128,
8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128,
24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128
};
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
@@ -109,36 +112,45 @@ static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) {
return false;
}
// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget.
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
//
// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead
// per_n_cost: bytes per nc column (weight + scratch buffers)
// per_m_cost: bytes per mc row (activation)
// per_mn_cost: bytes per mc*nc element (output)
// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.)
// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead
//
// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost.
// All matmul paths repeat weight processing per M-block and activation loading
// per N-block, so discrete block counts drive total overhead.
// Tie-break: when cost is equal, prefer larger mc * nc.
//
// Caller-provided coefficients:
// m_block_cost: penalty per extra M-block (weight redundancy, scales with n).
// n_block_cost: penalty per extra N-block (activation redundancy, scales with m).
//
// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max.
// Returns 0 on success, -1 if VTCM is insufficient.
static int hmx_compute_chunks(
size_t vtcm_total, size_t overhead,
size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost,
int m, int n,
size_t *m_chunk_out, size_t *n_chunk_out,
size_t *total_out)
{
static int hmx_compute_chunks(size_t vtcm_total,
size_t overhead,
size_t per_n_cost,
size_t per_m_cost,
size_t per_mn_cost,
int m,
int n,
size_t m_block_cost,
size_t n_block_cost,
size_t * m_chunk_out,
size_t * n_chunk_out,
size_t * total_out) {
if (m <= 0 || n <= 0) return -1;
if (vtcm_total <= overhead) return -1;
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
const size_t usable = vtcm_total - overhead;
size_t best_mn = 0, best_m = 0, best_n = 0;
size_t best_cost = SIZE_MAX;
size_t best_mn = 0;
size_t best_m = 0, best_n = 0;
const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS);
for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) {
// Early exit: if nc * m_max cannot beat best, smaller nc won't either
if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn)
break;
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
if (n_fixed >= usable) goto next_nc;
@@ -152,10 +164,19 @@ static int hmx_compute_chunks(
mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS);
mc = hex_smin(mc, (size_t)m);
if (mc > 0 && mc * nc > best_mn) {
best_mn = mc * nc;
best_m = mc;
best_n = nc;
if (mc == 0) {
goto next_nc;
}
size_t mblocks = ((size_t) m + mc - 1) / mc;
size_t nblocks = ((size_t) n + nc - 1) / nc;
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
size_t mn = mc * nc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_m = mc;
best_n = nc;
}
}
@@ -233,7 +254,7 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
// Shuffle before LUT
v_quants = Q6_Vb_vshuff_Vb(v_quants);
@@ -257,7 +278,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
// Load all 128 packed bytes (4 contiguous 32-byte groups)
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
// Shuffle before LUT
@@ -277,10 +298,8 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
out[0] = v_lo; // group0 already in [0:63]
out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63]
out[2] = v_hi; // group2 already in [0:63]
out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63]
out[0] = v_lo; // group0 already in [0:63]
out[1] = v_hi; // group2 already in [0:63]
}
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
@@ -384,8 +403,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
size_t row_stride, int weight_type,
int start_tile, int end_tile) {
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
@@ -398,47 +418,46 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes)
for (int t = start_tile; t < end_tile; ) {
int ct = t / n_k_tiles; // column tile index
int kt = t % n_k_tiles; // K tile index
unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
for (unsigned t = start_tile; t < end_tile; ) {
if (kt >= n_k_tiles) { kt = 0; ct++; }
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
((t + 3) / n_k_tiles == ct)) {
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
__fp16 *tile_bases[4];
for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
HVX_Vector v_off = v_scat_base;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t *r0 = vtcm_src + row0 * row_stride;
const uint8_t *r1 = vtcm_src + row1 * row_stride;
HVX_Vector v0[4], v1[4];
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
HVX_Vector v0[2];
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
if (row1 < n_cols) {
dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1);
} else {
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
}
for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); }
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); }
r0 = vtcm_src + row_offset; row_offset += row_stride;
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
t += 4;
t += 4; kt += 4;
continue;
}
@@ -495,20 +514,19 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
// --- Single-tile fallback ---
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
if (is_q4) {
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
HVX_Vector v_off = v_scat_base; // reset to column 0
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t *r0 = vtcm_src + row0 * row_stride;
const uint8_t *r1 = vtcm_src + row1 * row_stride;
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
@@ -585,7 +603,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
}
(void) *(volatile HVX_Vector *)(tile_base);
}
++t;
++t; ++kt;
}
// Drain HVX scatter write buffer: a vmem load on the same HW thread retires
@@ -653,9 +671,13 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
// --- End x4x2 dequantizers ---
// requires external HMX lock
static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales,
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
hmx_set_output_scales(scales);
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
Q6_bias_mxmem2_A((void *)scales);
for (int r = 0; r < n_row_tiles; ++r) {
for (int c = 0; c < n_col_tiles; ++c) {
@@ -665,16 +687,55 @@ static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const
const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
for (int k = 0; k < n_dot_tiles; ++k) {
int offset = k * HMX_FP16_TILE_N_ELMS;
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
__fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
hmx_consume_accumulator_fp16(out_tile);
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
}
// --- Async HMX matmul job (for pipeline overlap) ---
typedef struct {
__fp16 * output;
const __fp16 * activation;
const __fp16 * weight;
const __fp16 * scales;
uint32_t n_row_tiles;
uint32_t n_col_tiles;
uint32_t n_dot_tiles;
} hmx_matmul_job_t;
static void hmx_matmul_worker_fn(void * data) {
hmx_matmul_job_t * job = (hmx_matmul_job_t *) data;
FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles);
core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles);
}
static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
__fp16 * output,
const __fp16 * activation,
const __fp16 * weight,
const __fp16 * scales,
int n_row_tiles,
int n_col_tiles,
int n_dot_tiles) {
job->output = output;
job->activation = activation;
job->weight = weight;
job->scales = scales;
job->n_row_tiles = n_row_tiles;
job->n_col_tiles = n_col_tiles;
job->n_dot_tiles = n_dot_tiles;
}
// --- End async HMX matmul job ---
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) {
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;
@@ -832,12 +893,13 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0;
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
// FP16 weight: interleave and activation load have similar per-element cost.
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
/*per_n=*/3 * vec_dot_size,
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
/*per_mn=*/sizeof(__fp16),
params->m, params->n,
&m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
/*per_n=*/3 * vec_dot_size,
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
/*per_mn=*/sizeof(__fp16), params->m, params->n,
/*m_block_cost=*/(size_t) params->n,
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
}
@@ -1006,13 +1068,15 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0;
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
// FP16 weight: interleave and activation load have similar per-element cost.
if (hmx_compute_chunks(vtcm_budget,
/*overhead=*/ 256,
/*per_n=*/ 3 * vec_dot_size, // W + S0 + S1
/*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
/*per_mn=*/ sizeof(__fp16), // O
m, n,
&m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
/*overhead=*/256,
/*per_n=*/3 * vec_dot_size, // W + S0 + S1
/*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
/*per_mn=*/sizeof(__fp16), // O
m, n,
/*m_block_cost=*/(size_t) n,
/*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
return -1;
}
@@ -1157,6 +1221,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
int k, int n, int w_type);
#define FALLBACK_TO_STANDARD 1
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
const uint8_t *restrict permuted_weight, int m, int k, int n,
int weight_type) {
@@ -1169,9 +1235,12 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
// for large m, k (e.g. prefill FFN Down), use out-stationary version
if (m >= 128 && k > n && n > 1024) {
FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)",
m, k, n, weight_type, (k + 511) / 512);
return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
if (rc != FALLBACK_TO_STANDARD) {
return rc; // 0 success, -1 error
}
FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n);
// fall through to standard path
}
size_t row_stride = get_x4x2_row_stride(weight_type, k);
@@ -1197,9 +1266,10 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
}
size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost,
m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
// Quantized weight: dequant ~1.5x more expensive per element than activation load.
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n,
/*m_block_cost=*/(size_t) n * 3,
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)",
__func__, m, k, n, use_pipeline, vtcm_budget);
return -1;
@@ -1256,9 +1326,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
if (!use_pipeline) {
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
// transfer activation matrix chunk into VTCM
size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
@@ -1318,20 +1387,22 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
TIMER_STOP(output_store);
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
} else {
// 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
// stage B and D (dequantize and store) are expected to be on the critical path
// HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D).
// A --> B: vtcm_qweight, 1 buffer
// B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
// C --> D: vtcm_output0/vtcm_output1, 2 buffers
//
// LD ||A3| | B3 ||
// MM || C2 ||
// ST || D1 | ||
// Async timeline (C overlaps B+D):
// main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2]
// HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████]
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors
for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);
@@ -1352,31 +1423,34 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
}
// prologue: B0, A1, C0, B1
// prologue: B0, A1, submit C0 (async), B1 (overlaps C0)
{
// B0
// B0: wait for DMA, dequant weight chunk 0
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);
// A1
// A1: issue DMA for weight chunk 1
const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
if (1 < n_chunk_cnt) {
const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
}
// C0
core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
// submit C0 (non-blocking — HMX worker executes in parallel)
hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation,
(__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0]));
// B1
// B1: DMA pop + dequant (runs in parallel with C0 on HMX worker)
if (1 < n_chunk_cnt) {
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
}
}
// main loop
// main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1})
for (int i = 0; i < n_chunk_cnt; ++i) {
const size_t nc = i * n_chunk_n_cols;
const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
@@ -1386,36 +1460,41 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);
// issue A_{i+2}
// issue A_{i+2}: DMA push (non-blocking)
if (i + 2 < n_chunk_cnt) {
const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
}
// wait for HMX (C_{i}) -- C_{i} is done
// wait C_i: block until prologue/previous C completes
hmx_queue_pop(ctx->hmx_queue);
// result of B_{i+1} (input of C_{i+1}) should be ready now
// issue C_{i+1}
// submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below)
// job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's
// counterpart — and (i+1)%2 was last used by C_{i-1} which completed
// before C_i was submitted.
if (i + 1 < n_chunk_cnt) {
core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales,
hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2],
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2],
vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS),
hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2]));
}
// compute D_{i}
// D_i: store output (multi-thread HVX, parallel with C_{i+1})
float *output_chunk = dst + (mr * n + nc);
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);
// wait for DMA (A_{i+2}), compute B_{i+2}
// B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1})
if (i + 2 < n_chunk_cnt) {
dma_queue_pop(ctx->dma[0]);
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
}
}
}
}
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
hmx_queue_suspend(ctx->hmx_queue);
}
TIMER_STOP(total);
@@ -1434,10 +1513,13 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
}
// C += AB
void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile,
void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile,
int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) {
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
hmx_set_output_scales(col_scales);
Q6_bias_mxmem2_A((void *)col_scales);
for (int i = 0; i < n_row_tiles; ++i) {
for (int j = 0; j < n_col_tiles; ++j) {
@@ -1448,15 +1530,17 @@ void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp
__fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
if (!zero_init) {
hmx_load_tile_pair_fp16(accum_tile, eye_tile);
Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
}
for (int k = 0; k < n_dot_tiles; ++k) {
int offset = k * HMX_FP16_TILE_N_ELMS;
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
hmx_consume_accumulator_fp16(accum_tile);
Q6_mxmem_AR_after_hf(accum_tile, 0);
}
}
}
@@ -1540,12 +1624,41 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
const size_t vtcm_budget = ctx->vtcm_size;
const size_t M_BLOCK_SIZE = 512;
const size_t N_BLOCK_SIZE = 512;
const size_t K_BLOCK_SIZE = 512;
const size_t K_BLOCK_SIZE = 1024;
// Compute precise buffer sizes
// Fallback: if k doesn't need K-blocking, out-stationary has no advantage
const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE;
if (k_iters_check <= 1) {
FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k);
return FALLBACK_TO_STANDARD;
}
// Dynamic M,N search via hmx_compute_chunks
const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE);
const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32)
+ K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles)
const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant)
+ K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles)
const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary)
// Alignment margin: hex_align_up can add up to 2047 bytes per buffer;
// scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin
const size_t align_margin = 4 * HMX_FP16_TILE_SIZE;
const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment
size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used;
// Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost.
// From profiling: wt_dequant per element ≈ 1.5× activation load per element.
// m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive).
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
const size_t m_block_cost = (size_t) n * 3;
const size_t n_block_cost = (size_t) m * 2;
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
&N_BLOCK_SIZE, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
return -1;
}
// Compute precise buffer sizes from searched M,N and fixed K
const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
@@ -1554,7 +1667,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256;
if (total_vtcm > vtcm_budget) {
FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n);
FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm,
vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE);
return -1;
}
@@ -1568,8 +1682,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type,
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type,
M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
// initialize eye tile (32x32 identity matrix)
{

View File

@@ -0,0 +1,158 @@
#pragma clang diagnostic ignored "-Wunused-function"
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <qurt_thread.h>
#include <qurt_futex.h>
#include <HAP_compute_res.h>
#include "hmx-queue.h"
#define QURT_LOWEST_PRIO (254)
static inline void hmx_lock(struct hmx_queue *q)
{
if (!q->hmx_locked) {
HAP_compute_res_hmx_lock(q->hap_rctx);
q->hmx_locked = true;
}
}
static inline void hmx_unlock(struct hmx_queue *q)
{
if (q->hmx_locked) {
HAP_compute_res_hmx_unlock(q->hap_rctx);
q->hmx_locked = false;
}
}
static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
unsigned int ir = atomic_load(&q->idx_read);
while (ir != atomic_load(&q->idx_write)) {
struct hmx_queue_desc *d = &q->desc[ir];
if (!d->done) {
FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data);
enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func;
switch (sig) {
case HMX_QUEUE_NOOP: /* noop */; break;
case HMX_QUEUE_KILL: *killed = true; break;
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
default:
hmx_lock(q);
d->func(d->data);
break;
}
atomic_fetch_add(&d->done, 1);
}
ir = (ir + 1) & q->idx_mask;
atomic_store(&q->idx_read, ir);
}
}
static void hmx_queue_thread(void * arg) {
struct hmx_queue * q = (struct hmx_queue *) arg;
FARF(HIGH, "hmx-queue-thread: started");
bool killed = false;
unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT;
unsigned int prev_seqn = 0;
while (!killed) {
unsigned int seqn = atomic_load(&q->seqn);
if (seqn == prev_seqn) {
if (--poll_cnt) { hex_pause(); continue; }
FARF(HIGH, "hmx-queue-thread: sleeping");
qurt_futex_wait(&q->seqn, prev_seqn);
continue;
}
prev_seqn = seqn;
poll_cnt = HMX_QUEUE_POLL_COUNT;
FARF(HIGH, "hmx-queue-thread: new work");
hmx_queue_process(q, &killed);
}
FARF(HIGH, "hmx-queue-thread: stopped");
}
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) {
capacity = hex_ceil_pow2(capacity);
struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue));
if (q == NULL) {
FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
return NULL;
}
memset(q, 0, sizeof(struct hmx_queue));
q->capacity = capacity;
q->idx_mask = capacity - 1;
q->hap_rctx = hap_rctx;
q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc));
if (!q->desc) {
FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n");
return NULL;
}
memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc));
const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE;
q->stack = (unsigned char *) memalign(64, stack_size);
if (!q->stack) {
FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size);
return NULL;
}
memset(q->stack, 0, stack_size);
// Match caller thread priority (same pattern as worker-pool.c).
int prio = qurt_thread_get_priority(qurt_thread_get_id());
if (prio < 1) {
prio = 1;
}
if (prio > QURT_LOWEST_PRIO) {
prio = QURT_LOWEST_PRIO;
}
qurt_thread_attr_t attr;
qurt_thread_attr_init(&attr);
qurt_thread_attr_set_stack_addr(&attr, q->stack);
qurt_thread_attr_set_stack_size(&attr, stack_size);
qurt_thread_attr_set_priority(&attr, prio);
qurt_thread_attr_set_name(&attr, "hmx-queue");
int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q);
if (err) {
FARF(ERROR, "hmx-worker: thread create failed (%d)", err);
return NULL;
}
FARF(HIGH, "hmx-queue: capacity %u\n", capacity);
return q;
}
void hmx_queue_delete(struct hmx_queue * q) {
if (!q) {
return;
}
// Tell the worker to exit.
hmx_queue_flush(q);
hmx_queue_signal(q, HMX_QUEUE_KILL);
hmx_queue_flush(q);
int status;
qurt_thread_join(q->thread, &status);
free(q->desc);
free(q->stack);
free(q);
}

View File

@@ -0,0 +1,134 @@
#ifndef HMX_QUEUE_H
#define HMX_QUEUE_H
#include <stdbool.h>
#include <stdint.h>
#include <stdatomic.h>
#include <hexagon_types.h>
#include <qurt_thread.h>
#include <qurt_futex.h>
#include <HAP_farf.h>
#include "hex-utils.h"
#ifdef __cplusplus
extern "C" {
#endif
#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024)
#define HMX_QUEUE_POLL_COUNT 2000
typedef void (*hmx_queue_func)(void *);
// Dummy funcs used as signals
enum hmx_queue_signal {
HMX_QUEUE_NOOP = 0, // aka NULL
HMX_QUEUE_SUSPEND,
HMX_QUEUE_KILL
};
struct hmx_queue_desc {
hmx_queue_func func;
void * data;
atomic_uint done;
};
struct hmx_queue {
struct hmx_queue_desc * desc;
atomic_uint idx_write; // updated by producer (push)
atomic_uint idx_read; // updated by consumer (process)
unsigned int idx_pop; // updated by producer (pop)
uint32_t idx_mask;
uint32_t capacity;
atomic_uint seqn; // incremented for all pushes, used with futex
qurt_thread_t thread;
void * stack;
uint32_t hap_rctx;
bool hmx_locked;
};
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
void hmx_queue_delete(struct hmx_queue * q);
static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) {
struct hmx_queue_desc d = { func, data };
return d;
}
static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) {
unsigned int ir = atomic_load(&q->idx_read);
unsigned int iw = q->idx_write;
if (((iw + 1) & q->idx_mask) == ir) {
FARF(HIGH, "hmx-queue-push: queue is full\n");
return false;
}
atomic_store(&d.done, 0);
FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data);
q->desc[iw] = d;
atomic_store(&q->idx_write, (iw + 1) & q->idx_mask);
// wake up our thread
atomic_fetch_add(&q->seqn, 1);
qurt_futex_wake(&q->seqn, 1);
return true;
}
static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) {
return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL));
}
static inline bool hmx_queue_empty(struct hmx_queue * q) {
return q->idx_pop == q->idx_write;
}
static inline uint32_t hmx_queue_depth(struct hmx_queue * q) {
return (q->idx_read - q->idx_read) & q->idx_mask;
}
static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) {
return q->capacity;
}
static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) {
unsigned int ip = q->idx_pop;
unsigned int iw = q->idx_write;
struct hmx_queue_desc rd = { NULL, NULL };
if (ip == iw) {
return rd;
}
// Wait for desc to complete
struct hmx_queue_desc * d = &q->desc[ip];
while (!atomic_load(&d->done)) {
FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip);
hex_pause();
}
rd = *d;
q->idx_pop = (ip + 1) & q->idx_mask;
FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data);
return rd;
}
static inline void hmx_queue_flush(struct hmx_queue * q) {
while (hmx_queue_pop(q).func != NULL) ;
}
static inline void hmx_queue_suspend(struct hmx_queue *q) {
hmx_queue_signal(q, HMX_QUEUE_SUSPEND);
hmx_queue_flush(q);
}
#ifdef __cplusplus
} // extern "C"
#endif
#endif /* HMX_QUEUE_H */

View File

@@ -14,10 +14,6 @@
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
}
// Initialise aligned 256-byte area with scale vector + zero padding.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
HVX_Vector *pv = (HVX_Vector *)out_scales;
@@ -25,58 +21,6 @@ static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vecto
*pv = Q6_V_vzero();
}
// Load multiple contiguous tiles with :deep streaming.
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
// boundary, otherwise the mxmem instruction will raise a precise bus error.
// Callers must ensure their VTCM layout satisfies this constraint.
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
asm volatile(
"{ activation.hf = mxmem(%0, %1):deep\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
: "memory");
}
// Load a single activation+weight tile pair (no :deep streaming).
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
const __fp16 *wt_tile) {
asm volatile(
"{ activation.hf = mxmem(%0, %1)\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(act_tile), "r"(2047),
"r"(wt_tile), "r"(2047)
: "memory");
}
static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
// Use the combined convert-and-store instruction (matches the reference
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
asm volatile(
"mxmem(%0, %1):after.hf = acc\n"
:: "r"(out), "r"(0)
: "memory");
}
// Compute inner product of two vectors of tiles and store result.
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
hmx_consume_accumulator_fp16(out);
}
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {

View File

@@ -2,6 +2,7 @@
#define HTP_CTX_H
#include "hex-dma.h"
#include "hmx-queue.h"
#include "htp-ops.h"
#include "worker-pool.h"
@@ -30,6 +31,8 @@ struct htp_spad {
uint32_t size_per_thread; // size per thread
};
struct htp_context;
// Context while processing an Op
// TODO: fold this into the main context
struct htp_ops_context {
@@ -72,6 +75,10 @@ struct htp_context {
atomic_bool vtcm_needs_release;
struct htp_ops_context octx;
#ifdef HTP_HAS_HMX
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
#endif
};
int op_matmul(struct htp_ops_context * octx);

View File

@@ -91,7 +91,12 @@ enum htp_op_code {
#define HTP_OP_MAX_BUFS 8
#define HTP_OP_MAX_REQS 256
#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS)
#if __HVX_ARCH__ < 75
#define HTP_OP_MAX_VMEM (3167538380u)
#else
#define HTP_OP_MAX_VMEM (3221225472u)
#endif
enum htp_tensor_flags {
HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights)

View File

@@ -116,9 +116,14 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
}
static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
#if __HVX_ARCH__ >= 81
HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0);
HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1);
#else
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
#endif
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
}

View File

@@ -18,8 +18,9 @@
#include <remote.h>
#include <string.h>
#include "hex-dma.h"
#include "hex-utils.h"
#include "hex-dma.h"
#include "hmx-queue.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@@ -324,6 +325,14 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
#ifdef HTP_HAS_HMX
ctx->hmx_enabled = use_hmx;
ctx->hmx_queue = NULL;
if (use_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (!ctx->hmx_queue) {
FARF(ERROR, "hmx-queue-create failed");
ctx->hmx_enabled = false;
}
}
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
#endif
@@ -389,7 +398,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
}
#ifdef HTP_HAS_HMX
ctx->hmx_enabled = 0;
if (ctx->hmx_queue) {
hmx_queue_delete(ctx->hmx_queue);
ctx->hmx_queue = NULL;
}
ctx->hmx_enabled = false;
#endif
vtcm_free(ctx);

View File

@@ -30,6 +30,8 @@ extern "C" {
void ggml_print_backtrace(void);
uint64_t ggml_graph_next_uid(void);
#ifndef MIN
# define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
@@ -338,6 +340,10 @@ struct ggml_cgraph {
struct ggml_hash_set visited_hash_set;
enum ggml_cgraph_eval_order order;
// an optional identifier that can be utilized to recognize same graphs if two non-zero values match
// a value of 0 means it is not set and should be ignored
uint64_t uid;
};
// returns a slice of cgraph with nodes [i0, i1)

View File

@@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
@@ -1818,6 +1819,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ROLL);
char base[256];
char name[256];
snprintf(base, 256, "kernel_roll_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_PAD);

View File

@@ -152,6 +152,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
case GGML_UNARY_OP_XIELU:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
@@ -1137,6 +1138,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_ARGSORT:
case GGML_OP_TOP_K:
case GGML_OP_ARANGE:
case GGML_OP_ROLL:
return true;
case GGML_OP_FLASH_ATTN_EXT:
// for new head sizes, add checks here
@@ -1159,6 +1161,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
if (op->src[1]->type != op->src[2]->type) {
return false;
}
switch (op->src[1]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
break;
case GGML_TYPE_BF16:
if (!has_bfloat) {
return false;
}
break;
default:
return false;
}
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:

View File

@@ -127,6 +127,7 @@
#define OP_UNARY_NUM_CEIL 118
#define OP_UNARY_NUM_ROUND 119
#define OP_UNARY_NUM_TRUNC 120
#define OP_UNARY_NUM_XIELU 121
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
@@ -1016,6 +1017,29 @@ typedef struct {
int32_t p1;
} ggml_metal_kargs_pad_reflect_1d;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t s0;
int32_t s1;
int32_t s2;
int32_t s3;
} ggml_metal_kargs_roll;
typedef struct {
uint64_t nb1;
int dim;

View File

@@ -410,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
} break;
case GGML_OP_ROLL:
{
n_fuse = ggml_metal_op_roll(ctx, idx);
} break;
case GGML_OP_ARANGE:
{
n_fuse = ggml_metal_op_arange(ctx, idx);
@@ -787,6 +791,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
args.max = ggml_get_op_params_f32(op, 1);
}
if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
args.bias = ggml_get_op_params_f32(op, 3); // beta
args.val = ggml_get_op_params_f32(op, 4); // eps
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
if (pipeline.c4) {
@@ -3938,6 +3949,59 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ggml_get_op_params_i32(op, 0);
const int32_t s1 = ggml_get_op_params_i32(op, 1);
const int32_t s2 = ggml_get_op_params_i32(op, 2);
const int32_t s3 = ggml_get_op_params_i32(op, 3);
ggml_metal_kargs_roll args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.s0 =*/ s0,
/*.s1 =*/ s1,
/*.s2 =*/ s2,
/*.s3 =*/ s3
};
auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op);
const int nth = std::min(1024, ne0);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
return 1;
}
int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@@ -81,6 +81,7 @@ int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_roll (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);

View File

@@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl(
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
if (FC_OP == OP_UNARY_NUM_XIELU) {
const TC xi = x;
const TC gate = TC(xi > TC(0.0f));
const TC clamped = fmin(xi, TC(args.val));
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
}
}
#undef FC_OP
@@ -5238,6 +5247,40 @@ kernel void kernel_upscale_bicubic_f32(
}
}
kernel void kernel_roll_f32(
constant ggml_metal_kargs_roll & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *) src0;
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// apply shifts and wrap around
int64_t i00 = i0 - args.s0;
int64_t i01 = i1 - args.s1;
int64_t i02 = i2 - args.s2;
int64_t i03 = i3 - args.s3;
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
dst_ptr[dst_idx] = src0_ptr[src_idx];
}
}
kernel void kernel_pad_f32(
constant ggml_metal_kargs_pad & args,
device const char * src0,

View File

@@ -7,3 +7,26 @@ ggml_add_backend_library(ggml-rpc
if (WIN32)
target_link_libraries(ggml-rpc PRIVATE ws2_32)
endif()
# RDMA auto-detection (Linux only, requires libibverbs)
if (NOT WIN32 AND NOT APPLE)
find_library(IBVERBS_LIB ibverbs)
if (IBVERBS_LIB)
option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON)
else()
option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF)
endif()
else()
set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE)
endif()
if (GGML_RPC_RDMA)
if (NOT IBVERBS_LIB)
find_library(IBVERBS_LIB ibverbs REQUIRED)
endif()
target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA)
target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB})
message(STATUS " RDMA transport enabled (auto-detected)")
else()
message(STATUS " RDMA transport disabled")
endif()

View File

@@ -3,7 +3,9 @@
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"
#include <array>
#include <cinttypes>
#include <optional>
#include <string>
#include <vector>
#include <memory>
@@ -31,6 +33,14 @@
#include <filesystem>
#include <algorithm>
#ifdef GGML_RPC_RDMA
# include <infiniband/verbs.h>
# include <time.h>
# ifndef _WIN32
# include <poll.h>
# endif
#endif // GGML_RPC_RDMA
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
#define LOG_DBG(...) \
@@ -49,17 +59,116 @@ typedef int sockfd_t;
#endif
// cross-platform socket
#ifdef GGML_RPC_RDMA
static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock)
static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB
static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes
using rdma_gid_t = std::array<uint8_t, RDMA_GID_SIZE>;
struct rdma_conn {
struct ibv_context * ctx = nullptr;
struct ibv_pd * pd = nullptr;
struct ibv_cq * scq = nullptr; // send completions
struct ibv_cq * rcq = nullptr; // recv completions
struct ibv_qp * qp = nullptr;
void * tx_buf = nullptr;
struct ibv_mr * tx_mr = nullptr;
void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous
struct ibv_mr * rx_mr = nullptr;
int rx_head = 0;
uint32_t max_inline = 0;
uint8_t * rx_slot(int i) const {
return static_cast<uint8_t *>(rx_buf) + static_cast<size_t>(i) * RDMA_CHUNK;
}
bool post_rx(int i) {
struct ibv_sge sge = {};
sge.addr = (uintptr_t)rx_slot(i);
sge.length = RDMA_CHUNK;
sge.lkey = rx_mr->lkey;
struct ibv_recv_wr wr = {}, * bad = nullptr;
wr.wr_id = (uint64_t)i;
wr.sg_list = &sge;
wr.num_sge = 1;
return ibv_post_recv(qp, &wr, &bad) == 0;
}
~rdma_conn() {
if (tx_mr) ibv_dereg_mr(tx_mr);
if (rx_mr) ibv_dereg_mr(rx_mr);
free(tx_buf);
free(rx_buf);
if (qp) ibv_destroy_qp(qp);
if (scq) ibv_destroy_cq(scq);
if (rcq) ibv_destroy_cq(rcq);
if (pd) ibv_dealloc_pd(pd);
if (ctx) ibv_close_device(ctx);
}
};
// Local RDMA parameters captured during the probe phase and later consumed
// by rdma_activate() after the remote side's caps arrive via HELLO.
struct rdma_local_info {
uint32_t qpn = 0;
uint32_t psn = 0;
uint8_t gid[RDMA_GID_SIZE] = {};
uint8_t ib_port = 0;
int gid_idx = 0;
enum ibv_mtu path_mtu = IBV_MTU_1024;
};
#endif // GGML_RPC_RDMA
// conn_caps size for transport-agnostic capability exchange
static constexpr size_t RPC_CONN_CAPS_SIZE = 24;
// conn_caps RDMA layout helper
#ifdef GGML_RPC_RDMA
struct rdma_caps {
uint32_t qpn;
uint32_t psn;
uint8_t gid[RDMA_GID_SIZE];
};
static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size");
#endif // GGML_RPC_RDMA
// Forward declarations for transport function pointers
struct socket_t;
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size);
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size);
struct socket_t {
sockfd_t fd;
bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl;
bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl;
#ifdef GGML_RPC_RDMA
std::unique_ptr<rdma_conn> rdma;
rdma_local_info rdma_local = {};
#endif // GGML_RPC_RDMA
socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() {
#ifdef GGML_RPC_RDMA
rdma.reset();
#endif // GGML_RPC_RDMA
LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
closesocket(this->fd);
if (fd != INVALID_SOCKET) closesocket(this->fd);
#else
close(this->fd);
if (fd >= 0) close(this->fd);
#endif
}
// Advertise local transport capabilities into conn_caps.
// May probe RDMA and store the probe on this socket for update_caps.
void get_caps(uint8_t * caps);
// Activate transport upgrade based on remote conn_caps using the probe
// previously stored by get_caps.
void update_caps(const uint8_t * remote_caps);
};
// macro for nicer error messages on server crash
@@ -115,10 +224,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
struct rpc_msg_hello_req {
uint8_t conn_caps[RPC_CONN_CAPS_SIZE];
};
struct rpc_msg_hello_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
uint8_t padding;
uint8_t conn_caps[RPC_CONN_CAPS_SIZE];
};
struct rpc_msg_device_count_rsp {
@@ -414,27 +529,414 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true;
}
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
return false;
}
return send_data(sockfd, msg, msg_size);
// TCP transport implementations (for function-pointer dispatch)
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) {
return send_data(sock->fd, data, size);
}
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) {
return recv_data(sock->fd, data, size);
}
// RDMA transport (performance-optimized, auto-negotiated)
#ifdef GGML_RPC_RDMA
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size);
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size);
static inline bool tcp_peer_closed(int fd) {
if (fd < 0) return false;
#ifndef _WIN32
struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 };
int r = poll(&pfd, 1, 0);
return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP));
#else
return false;
#endif
}
static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) {
for (uint64_t s = 0; ; s++) {
int n = ibv_poll_cq(cq, 1, wc);
if (n > 0) {
if (wc->status != IBV_WC_SUCCESS) {
GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n",
wc->status, ibv_wc_status_str(wc->status), wc->vendor_err);
}
return wc->status == IBV_WC_SUCCESS;
}
if (n < 0) return false;
if ((s & 0xFFFFF) == 0 && s > 0) {
if (tcp_peer_closed(tcp_fd)) {
return false;
}
}
}
}
static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) {
const uint8_t * src = (const uint8_t *)data;
size_t rem = size;
while (rem > 0) {
size_t chunk = std::min(rem, RDMA_CHUNK);
struct ibv_sge sge = {};
struct ibv_send_wr wr = {}, * bad = nullptr;
wr.opcode = IBV_WR_SEND;
wr.sg_list = &sge;
wr.num_sge = 1;
if (chunk <= c->max_inline) {
sge.addr = (uintptr_t)src;
sge.length = chunk;
wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;
} else {
memcpy(c->tx_buf, src, chunk);
sge.addr = (uintptr_t)c->tx_buf;
sge.length = chunk;
sge.lkey = c->tx_mr->lkey;
wr.send_flags = IBV_SEND_SIGNALED;
}
if (ibv_post_send(c->qp, &wr, &bad) != 0) return false;
struct ibv_wc wc;
if (!rdma_poll(c->scq, &wc, tcp_fd)) return false;
src += chunk;
rem -= chunk;
}
return true;
}
static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) {
uint8_t * dst = (uint8_t *)data;
size_t rem = size;
while (rem > 0) {
struct ibv_wc wc;
if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false;
int slot = (int)wc.wr_id;
size_t got = wc.byte_len;
memcpy(dst, c->rx_slot(slot), got);
if (!c->post_rx(slot)) return false;
dst += got;
rem -= got;
}
return true;
}
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) {
return rdma_send(sock->rdma.get(), data, size, sock->fd);
}
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) {
return rdma_recv(sock->rdma.get(), data, size, sock->fd);
}
// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address.
// Used to match the socket's local IP against the kernel's GID table so that
// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly:
// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4)
// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape)
// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is
// Returns std::nullopt on unsupported family or getsockname failure.
static std::optional<rdma_gid_t> rdma_build_target_gid(sockfd_t tcp_fd) {
sockaddr_storage addr = {};
socklen_t addr_len = sizeof(addr);
if (getsockname(tcp_fd, reinterpret_cast<sockaddr *>(&addr), &addr_len) != 0) {
return std::nullopt;
}
rdma_gid_t target = {};
if (addr.ss_family == AF_INET) {
const auto * a = reinterpret_cast<const sockaddr_in *>(&addr);
target[10] = 0xff;
target[11] = 0xff;
memcpy(&target[12], &a->sin_addr, 4);
return target;
}
if (addr.ss_family == AF_INET6) {
const auto * a = reinterpret_cast<const sockaddr_in6 *>(&addr);
memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE);
return target;
}
return std::nullopt;
}
static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) {
const char * dev_env = std::getenv("GGML_RDMA_DEV");
const char * gid_env = std::getenv("GGML_RDMA_GID");
auto target_gid = rdma_build_target_gid(tcp_fd);
if (!target_gid) {
return nullptr;
}
const uint8_t ib_port = 1;
int num_devs = 0;
ibv_device ** devs = ibv_get_device_list(&num_devs);
if (!devs || num_devs == 0) return nullptr;
ibv_context * ibctx = nullptr;
const char * matched_dev = nullptr;
int gid_idx = gid_env ? atoi(gid_env) : -1;
int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB
for (int d = 0; d < num_devs; d++) {
const char * dn = ibv_get_device_name(devs[d]);
if (dev_env && strcmp(dev_env, dn) != 0) continue;
ibv_context * ctx = ibv_open_device(devs[d]);
if (!ctx) continue;
ibv_port_attr pa;
if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; }
int found_gid = gid_idx;
int found_version = IBV_GID_TYPE_IB;
if (found_gid < 0) {
// Find a GID on this port whose bytes equal the local TCP address
// (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1
// (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths
// are avoided. ibv_query_gid_ex returns gid+type in one call.
int v2_idx = -1;
int v1_idx = -1;
for (int i = 0; i < pa.gid_tbl_len; i++) {
ibv_gid_entry entry = {};
if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue;
if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue;
if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) {
v2_idx = i;
} else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) {
v1_idx = i;
}
}
if (v2_idx >= 0) {
found_gid = v2_idx;
found_version = IBV_GID_TYPE_ROCE_V2;
} else if (v1_idx >= 0) {
found_gid = v1_idx;
found_version = IBV_GID_TYPE_ROCE_V1;
}
} else {
// Explicit GID index from GGML_RDMA_GID — fetch its type for logging.
ibv_gid_entry entry = {};
if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) {
found_version = entry.gid_type;
}
}
if (found_gid >= 0) {
ibctx = ctx;
gid_idx = found_gid;
gid_version = found_version;
matched_dev = dn;
out->path_mtu = pa.active_mtu;
break;
}
ibv_close_device(ctx);
}
ibv_free_device_list(devs);
if (!ibctx) return nullptr;
out->ib_port = ib_port;
out->gid_idx = gid_idx;
// unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(),
// so each failure path is a plain `return nullptr;`.
auto c = std::make_unique<rdma_conn>();
c->ctx = ibctx;
c->pd = ibv_alloc_pd(ibctx);
if (!c->pd) return nullptr;
c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0);
c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0);
if (!c->scq || !c->rcq) return nullptr;
ibv_qp_init_attr qia = {};
qia.send_cq = c->scq;
qia.recv_cq = c->rcq;
qia.qp_type = IBV_QPT_RC;
qia.cap.max_send_wr = 4;
qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4;
qia.cap.max_send_sge = 1;
qia.cap.max_recv_sge = 1;
qia.cap.max_inline_data = 256;
c->qp = ibv_create_qp(c->pd, &qia);
if (!c->qp) return nullptr;
c->max_inline = qia.cap.max_inline_data;
c->tx_buf = aligned_alloc(4096, RDMA_CHUNK);
c->rx_buf = aligned_alloc(4096, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK);
if (!c->tx_buf || !c->rx_buf) return nullptr;
c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE);
c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!c->tx_mr || !c->rx_mr) return nullptr;
ibv_gid local_gid;
if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr;
out->qpn = c->qp->qp_num;
out->psn = c->qp->qp_num & 0xffffff;
memcpy(out->gid, &local_gid, RDMA_GID_SIZE);
const char * ver_str = "";
if (gid_version == IBV_GID_TYPE_ROCE_V2) {
ver_str = " RoCEv2";
} else if (gid_version == IBV_GID_TYPE_ROCE_V1) {
ver_str = " RoCEv1";
}
GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n",
matched_dev, gid_idx, ver_str, out->qpn, c->max_inline);
return c.release();
}
// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS.
// On success, the connection is live and ready for rdma_send/rdma_recv.
static bool rdma_activate(rdma_conn * c, const rdma_local_info * local,
uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) {
// RESET -> INIT
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_INIT;
a.port_num = local->ib_port;
a.pkey_index = 0;
a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
return false;
}
}
for (int i = 0; i < RDMA_RX_DEPTH; i++) {
if (!c->post_rx(i)) return false;
}
// INIT -> RTR
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTR;
a.path_mtu = local->path_mtu;
a.dest_qp_num = remote_qpn;
a.rq_psn = remote_psn;
a.max_dest_rd_atomic = 1;
a.min_rnr_timer = 1;
a.ah_attr.is_global = 1;
memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE);
a.ah_attr.grh.hop_limit = 1;
a.ah_attr.grh.sgid_index = local->gid_idx;
a.ah_attr.dlid = 0;
a.ah_attr.port_num = local->ib_port;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) {
return false;
}
}
// RTR -> RTS
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTS;
a.timeout = 14;
a.retry_cnt = 7;
a.rnr_retry = 7;
a.sq_psn = local->psn;
a.max_rd_atomic = 1;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY |
IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) {
return false;
}
}
GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n",
local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH);
return true;
}
#endif // GGML_RPC_RDMA
// ---------------------------------------------------------------------------
// socket_t transport capability methods
// ---------------------------------------------------------------------------
void socket_t::get_caps(uint8_t * caps) {
memset(caps, 0, RPC_CONN_CAPS_SIZE);
#ifdef GGML_RPC_RDMA
rdma_local = {};
rdma.reset(rdma_probe(fd, &rdma_local));
if (rdma) {
rdma_caps rc = {};
rc.qpn = rdma_local.qpn;
rc.psn = rdma_local.psn;
memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE);
memcpy(caps, &rc, sizeof(rc));
}
#endif // GGML_RPC_RDMA
}
void socket_t::update_caps(const uint8_t * remote_caps) {
#ifdef GGML_RPC_RDMA
if (!rdma) {
return;
}
rdma_caps rc = {};
memcpy(&rc, remote_caps, sizeof(rc));
if (rc.qpn == 0) {
rdma.reset();
return;
}
if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) {
fn_send = rdma_send_impl;
fn_recv = rdma_recv_impl;
} else {
GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n");
rdma.reset();
}
#else
(void)remote_caps;
#endif // GGML_RPC_RDMA
}
// unified transport dispatch (via function pointers)
static bool send_data(socket_t * sock, const void * data, size_t size) {
return sock->fn_send(sock, data, size);
}
static bool recv_data(socket_t * sock, void * data, size_t size) {
return sock->fn_recv(sock, data, size);
}
static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) {
if (!send_data(sock, &msg_size, sizeof(msg_size))) {
return false;
}
return send_data(sock, msg, msg_size);
}
static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) {
uint64_t size;
if (!recv_data(sockfd, &size, sizeof(size))) {
if (!recv_data(sock, &size, sizeof(size))) {
return false;
}
if (size != msg_size) {
return false;
}
return recv_data(sockfd, msg, msg_size);
return recv_data(sock, msg, msg_size);
}
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
static bool recv_msg(socket_t * sock, std::vector<uint8_t> & input) {
uint64_t size;
if (!recv_data(sockfd, &size, sizeof(size))) {
if (!recv_data(sock, &size, sizeof(size))) {
return false;
}
try {
@@ -443,7 +945,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
return false;
}
return recv_data(sockfd, input.data(), size);
return recv_data(sock, input.data(), size);
}
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
@@ -452,7 +954,11 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
return false;
}
host = endpoint.substr(0, pos);
port = std::stoi(endpoint.substr(pos + 1));
try {
port = std::stoi(endpoint.substr(pos + 1));
} catch (...) {
return false;
}
return true;
}
@@ -460,13 +966,13 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
// No response
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
uint8_t cmd_byte = cmd;
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) {
return false;
}
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
if (!send_data(sock.get(), &input_size, sizeof(input_size))) {
return false;
}
if (!send_data(sock->fd, input, input_size)) {
if (!send_data(sock.get(), input, input_size)) {
return false;
}
return true;
@@ -478,16 +984,14 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
if (!send_rpc_cmd(sock, cmd, input, input_size)) {
return false;
}
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
// even if we do, we can skip sending output_size from the server for commands with known output size
uint64_t out_size;
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
if (!recv_data(sock.get(), &out_size, sizeof(out_size))) {
return false;
}
if (out_size != output_size) {
return false;
}
if (!recv_data(sock->fd, output, output_size)) {
if (!recv_data(sock.get(), output, output_size)) {
return false;
}
return true;
@@ -495,17 +999,25 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
// Performs HELLO handshake with transport auto-negotiation.
// Advertises local capabilities via conn_caps; if the server responds with
// matching capabilities, the socket is upgraded transparently.
static bool negotiate_hello(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_req request = {};
rpc_msg_hello_rsp response = {};
sock->get_caps(request.conn_caps);
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n",
response.major, response.minor, response.patch);
return false;
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
sock->update_caps(response.conn_caps);
return true;
}
@@ -527,6 +1039,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
}
#ifdef _WIN32
if (!initialized) {
WSADATA wsaData;
@@ -543,10 +1056,10 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (sock == nullptr) {
return nullptr;
}
if (!check_server_version(sock)) {
if (!negotiate_hello(sock)) {
return nullptr;
}
LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str());
sockets[endpoint] = sock;
return sock;
}
@@ -1597,25 +2110,46 @@ rpc_server::~rpc_server() {
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd) {
socket_t * sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
}
// the first command sent by the client must be HELLO
if (cmd != RPC_CMD_HELLO) {
GGML_LOG_ERROR("Expected HELLO command, update client\n");
return;
}
if (!recv_msg(sockfd, nullptr, 0)) {
// Read input_size and validate protocol version
uint64_t hello_input_size;
if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) {
return;
}
rpc_msg_hello_rsp response;
server.hello(response);
if (!send_msg(sockfd, &response, sizeof(response))) {
if (hello_input_size != sizeof(rpc_msg_hello_req)) {
GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n",
(size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION);
return;
}
rpc_msg_hello_req req = {};
if (!recv_data(sockfd, &req, sizeof(req))) {
return;
}
rpc_msg_hello_rsp rsp = {};
server.hello(rsp);
// Advertise server transport capabilities based on client's caps
sockfd->get_caps(rsp.conn_caps);
if (!send_msg(sockfd, &rsp, sizeof(rsp))) {
return;
}
// Activate transport upgrade using client's caps
sockfd->update_caps(req.conn_caps);
while (true) {
if (!recv_data(sockfd, &cmd, 1)) {
break;
@@ -1884,6 +2418,12 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
if (!parse_endpoint(endpoint, host, port)) {
return;
}
#ifdef GGML_RPC_RDMA
printf(" transport : TCP (RDMA auto-negotiate enabled)\n");
#else
printf(" transport : TCP\n");
#endif // GGML_RPC_RDMA
#ifdef _WIN32
{
WSADATA wsaData;
@@ -1907,7 +2447,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backends, cache_dir, client_socket->fd);
rpc_serve_client(backends, cache_dir, client_socket.get());
printf("Client connection closed\n");
fflush(stdout);
}

View File

@@ -154,6 +154,11 @@ if (GGML_SYCL_GRAPH)
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
endif()
if (GGML_SYCL_HOST_MEM_FALLBACK)
message(STATUS "find GGML_SYCL_HOST_MEM_FALLBACK")
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_HOST_MEM_FALLBACK)
endif()
if (GGML_SYCL_DEVICE_ARCH)
target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})

View File

@@ -151,6 +151,25 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
}
template <typename dst_t>
static void dequantize_row_q8_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
int constexpr WARP_K = WARP_SIZE * QK8_0;
const int n_warp = (k + WARP_K - 1) / WARP_K;
GGML_ASSERT(k % QK8_0 == 0);
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
dequantize_block_q8_0_reorder(vx, y, k, item_ct1);
});
}
template <typename dst_t>
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@@ -614,7 +633,12 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
case GGML_TYPE_Q5_1:
return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
if (dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q8_0_sycl_reorder;
} else {
return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
}
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_sycl;
case GGML_TYPE_Q3_K:
@@ -683,7 +707,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
case GGML_TYPE_Q5_1:
return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
if (dst->src[0]->extra &&
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q8_0_sycl_reorder;
} else {
return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
}
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_sycl;
case GGML_TYPE_Q3_K:

View File

@@ -239,6 +239,34 @@ static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t *
}
// Dequantize Q8_0 from reorder layout: [all qs (k bytes)][all d values]
// Each thread handles one block of QK8_0 elements.
template<typename dst_t>
static void dequantize_block_q8_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t k,
const sycl::nd_item<3> &item_ct1) {
const int64_t i = item_ct1.get_group(2);
const int64_t tid = item_ct1.get_local_id(2);
const int lane_ib = i * WARP_SIZE + tid;
if (lane_ib >= k / QK8_0) {
return;
}
dst_t * y_ptr = yy + lane_ib * QK8_0;
auto qs = (const int8_t*)vx + lane_ib * QK8_0;
auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k) + lane_ib;
const float d = float(*s_ptr);
#pragma unroll
for (int l = 0; l < QK8_0; ++l) {
y_ptr[l] = d * qs[l];
}
}
template<typename dst_t>
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
const sycl::nd_item<3> &item_ct1) {

View File

@@ -615,6 +615,162 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
}
}
static void dequantize_mul_mat_vec_q4_k_reorder(const void *__restrict__ vx,
const float *__restrict__ yy,
float *__restrict__ dst,
const int ncols, int nrows,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
// SOA base pointers for the reordered layout:
// [qs: nb * QK_K/2] [scales: nb * K_SCALE_SIZE] [dm: nb * sizeof(half2)]
const int nb = nrows * num_blocks_per_row;
const uint8_t * qs_base = (const uint8_t *)vx;
const uint8_t * scales_base = qs_base + (size_t)nb * (QK_K / 2);
const sycl::half2 * dm_base = (const sycl::half2 *)(scales_base + (size_t)nb * K_SCALE_SIZE);
#if QK_K == 256
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid =
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix =
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
#if K_QUANTS_PER_ITERATION == 2
uint32_t q32[4];
const uint8_t * q4 = (const uint8_t *)q32;
#else
uint16_t q16[4];
const uint8_t * q4 = (const uint8_t *)q16;
#endif
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int bi = ib0 + i;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const sycl::half2 dm_val = dm_base[bi];
const float dall = dm_val[0];
const float dmin = dm_val[1];
const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE);
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
#if K_QUANTS_PER_ITERATION == 2
const uint32_t * q1 = (const uint32_t *)(qs_base + bi * (QK_K / 2) + q_offset);
const uint32_t * q2 = q1 + 16;
q32[0] = q1[0] & 0x0f0f0f0f;
q32[1] = q1[0] & 0xf0f0f0f0;
q32[2] = q2[0] & 0x0f0f0f0f;
q32[3] = q2[0] & 0xf0f0f0f0;
sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 4; ++l) {
s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];
s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
dmin * smin;
#else
const uint16_t * q1 = (const uint16_t *)(qs_base + bi * (QK_K / 2) + q_offset);
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[0] & 0xf0f0;
q16[2] = q2[0] & 0x0f0f;
q16[3] = q2[0] & 0xf0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 2; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif
}
#else
const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
const int step = tid * K_QUANTS_PER_ITERATION;
uint16_t aux16[2];
const uint8_t * s = (const uint8_t *)aux16;
float tmp = 0;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const int bi = ib0 + i;
const uint8_t * q = qs_base + bi * (QK_K / 2) + step;
const float * y = yy + i*QK_K + step;
const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE);
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
const sycl::half2 dm_val = dm_base[bi];
const float d = (float)dm_val[0];
const float m = (float)dm_val[1];
float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+ y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+ y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+ y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (tid == 0) {
dst[row] = tmp;
}
}
/*
DPCT1110:7: The total declared local variable size in device function
dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register
@@ -864,6 +1020,129 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
}
}
static void dequantize_mul_mat_vec_q6_k_reorder(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
const sycl::nd_item<3> &item_ct1) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
// SOA base pointers for the reordered layout:
// [ql: nb * QK_K/2] [qh: nb * QK_K/4] [scales: nb * QK_K/16] [d: nb * sizeof(half)]
const int nb = nrows * num_blocks_per_row;
const uint8_t * ql_base = (const uint8_t *)vx;
const uint8_t * qh_base = ql_base + (size_t)nb * (QK_K / 2);
const int8_t * scales_base = (const int8_t *)(qh_base + (size_t)nb * (QK_K / 4));
const sycl::half * d_base = (const sycl::half *)((const uint8_t *)scales_base + (size_t)nb * (QK_K / 16));
#if QK_K == 256
const int tid =
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix =
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int bi = ib0 + i;
const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = ql_base + bi * (QK_K / 2) + ql_offset;
const uint8_t * qh = qh_base + bi * (QK_K / 4) + qh_offset;
const int8_t * s = scales_base + bi * (QK_K / 16) + s_offset;
const float d = d_base[bi];
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp += sum;
#endif
}
#else
const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7
const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3
const int step = tid * K_QUANTS_PER_ITERATION;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const int bi = ib0 + i;
const float * y = yy + i * QK_K + step;
const uint8_t * ql = ql_base + bi * (QK_K / 2) + step;
const uint8_t * qh = qh_base + bi * (QK_K / 4) + step;
const int8_t * s = scales_base + bi * (QK_K / 16);
const float d = d_base[bi];
float sum = 0;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+ y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
if (tid == 0) {
dst[row] = tmp;
}
}
static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
@@ -1167,6 +1446,38 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
});
}
static void dequantize_mul_mat_vec_q4_K_sycl_reorder(const void *vx, const float *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k_reorder(vx, y, dst, ncols, nrows, item_ct1);
});
}
static void dequantize_mul_mat_vec_q6_K_sycl_reorder(const void *vx, const float *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k_reorder(vx, y, dst, ncols, nrows, item_ct1);
});
}
void ggml_sycl_op_dequantize_mul_mat_vec(
ggml_backend_sycl_context & ctx,
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1235,8 +1546,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
case GGML_TYPE_Q4_K:
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
// reorder is currently not supported for dmmv
GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
dequantize_mul_mat_vec_q4_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
} else {
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
}
@@ -1245,7 +1555,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q6_K:
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
dequantize_mul_mat_vec_q6_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
} else {
dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
}
break;
case GGML_TYPE_F16:
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);

View File

@@ -3348,9 +3348,55 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
sycl::free(ptr, *stream);
}
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
// RAII wrapper for temporary reorder buffers with optional host memory fallback.
// When device allocation fails and GGML_SYCL_HOST_MEM_FALLBACK is enabled,
// falls back to host memory so the reorder kernel can still run (over PCIe).
// Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+).
struct sycl_reorder_temp_buffer {
void * ptr = nullptr;
dpct::queue_ptr stream;
sycl_reorder_temp_buffer(dpct::queue_ptr stream, size_t size) : stream(stream) {
ptr = sycl_ext_malloc_device(stream, size);
#ifdef GGML_SYCL_HOST_MEM_FALLBACK
if (!ptr) {
ptr = sycl::malloc_host(size, *stream);
if (ptr) {
host_fallback = true;
GGML_LOG_WARN("%s: device alloc of %zu bytes failed, using host memory fallback\n", __func__, size);
}
}
#endif
}
~sycl_reorder_temp_buffer() {
if (!ptr) {
return;
}
if (host_fallback) {
sycl::free(ptr, *stream);
} else {
sycl_ext_free(stream, ptr);
}
}
explicit operator bool() const { return ptr != nullptr; }
sycl_reorder_temp_buffer(const sycl_reorder_temp_buffer &) = delete;
sycl_reorder_temp_buffer & operator=(const sycl_reorder_temp_buffer &) = delete;
private:
bool host_fallback = false;
};
static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
dpct::queue_ptr stream) {
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3379,12 +3425,17 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
sycl_ext_free(stream, tmp_buf);
return true;
}
static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
dpct::queue_ptr stream) {
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3413,16 +3464,21 @@ static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nr
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
sycl_ext_free(stream, tmp_buf);
return true;
}
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
const int nblocks = size / sizeof(block_q4_K);
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3451,16 +3507,21 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
sycl_ext_free(stream, tmp_buf);
return true;
}
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
const int nblocks = size / sizeof(block_q6_K);
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
sycl_reorder_temp_buffer tmp(stream, size);
if (!tmp) {
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
return false;
}
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
sycl::event copy_event;
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
@@ -3499,10 +3560,10 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
if (!g_ggml_sycl_use_async_mem_op) {
reorder_event.wait_and_throw();
}
sycl_ext_free(stream, tmp_buf);
return true;
}
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
uint8_t * data_device = (uint8_t *) src0->data;
size_t ncols = src0->ne[0];
size_t nrows = src0->ne[1];
@@ -3510,20 +3571,16 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
switch (src0->type) {
case GGML_TYPE_Q4_0:
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
break;
return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
case GGML_TYPE_Q8_0:
reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
break;
return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
case GGML_TYPE_Q4_K:
reorder_qw_q4_k(data_device, size, 0, stream);
break;
return reorder_qw_q4_k(data_device, size, 0, stream);
case GGML_TYPE_Q6_K:
reorder_qw_q6_k(data_device, size, 0, stream);
break;
return reorder_qw_q6_k(data_device, size, 0, stream);
default:
GGML_ABORT("reorder_qw() called with unsupported type");
break;
return false;
}
}
@@ -3563,8 +3620,9 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
break;
}
reorder_qw(src0, ctx->stream());
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
if (reorder_qw(src0, ctx->stream())) {
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
}
}

View File

@@ -20,6 +20,13 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
#include <vulkan/vulkan.hpp>
// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and
// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan.
#if defined(_WIN32) && !defined(__MINGW32__)
#include <spirv-headers/spirv.hpp>
#else
#include <spirv/unified1/spirv.hpp>
#endif
#include <algorithm>
#include <cmath>
@@ -1387,7 +1394,7 @@ struct vk_op_im2col_push_constants {
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
uint32_t KW; uint32_t KH;
uint32_t pelements;
uint32_t OH_batch;
uint32_t CHW;
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
@@ -2131,6 +2138,66 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
// Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for
// separate shader variants compiled with -DRTE16.
std::vector<uint32_t> spv;
if (device->float_controls_rte_fp16) {
const uint32_t* spv_words = reinterpret_cast<const uint32_t *>(spv_data);
size_t word_count = spv_size / sizeof(uint32_t);
spv.assign(spv_words, spv_words + word_count);
// Find insertion points respecting SPIR-V layout order:
// Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ...
size_t pos = 5; // skip header
size_t cap_insert_pos = pos;
size_t ext_insert_pos = pos;
size_t exec_insert_pos = pos;
uint32_t entry_point_id = 0;
while (pos < spv.size()) {
uint32_t opcode = spv[pos] & spv::OpCodeMask;
uint32_t len = spv[pos] >> spv::WordCountShift;
if (len == 0) break;
if (opcode == spv::OpCapability) {
cap_insert_pos = pos + len;
ext_insert_pos = pos + len;
} else if (opcode == spv::OpExtension) {
ext_insert_pos = pos + len;
} else if (opcode == spv::OpEntryPoint) {
entry_point_id = spv[pos + 2];
exec_insert_pos = pos + len;
} else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) {
exec_insert_pos = pos + len;
} else if (entry_point_id != 0) {
break;
}
pos += len;
}
// Insert from latest position first so earlier indices stay valid.
// OpExecutionMode %entrypoint RoundingModeRTE 16
uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 };
spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode));
// OpExtension "SPV_KHR_float_controls"
const char ext_str[] = "SPV_KHR_float_controls";
size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t));
std::vector<uint32_t> extension(1 + ext_str_words, 0);
extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension;
memcpy(&extension[1], ext_str, sizeof(ext_str));
spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end());
// OpCapability RoundingModeRTE
uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE };
spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability));
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data());
}
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
vk::PushConstantRange pcr(
@@ -3079,6 +3146,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
case GGML_TYPE_MXFP4:
lut_size = 4*16;
break;
case GGML_TYPE_NVFP4:
// Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4).
lut_size = 4*16 + 128u * (uint32_t)sizeof(float);
break;
default:
break;
}
@@ -3558,6 +3629,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
GGML_ASSERT(device->subgroup_ballot);
@@ -3588,6 +3660,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -3651,6 +3724,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -3674,6 +3748,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
GGML_ASSERT(device->subgroup_ballot);
@@ -3708,6 +3783,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
@@ -3773,6 +3849,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -3819,6 +3896,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -3864,6 +3942,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -3939,6 +4018,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -3983,6 +4063,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -4010,6 +4091,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
}
// reusing CREATE_MM from the fp32 path
@@ -4108,6 +4190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
@@ -4133,6 +4216,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -4184,6 +4268,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -4239,6 +4324,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -4265,6 +4351,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -4291,6 +4378,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -4323,10 +4411,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
if (device->float_controls_rte_fp16 &&
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -4351,43 +4438,28 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
#define SET_ROWS(itype, rte) \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
#define SET_ROWS(itype) \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
if (device->float_controls_rte_fp16) {
SET_ROWS(_i32, _rte)
SET_ROWS(_i64, _rte)
} else {
SET_ROWS(_i32, )
SET_ROWS(_i64, )
}
SET_ROWS(_i32)
SET_ROWS(_i64)
#undef SET_ROWS
@@ -4407,11 +4479,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
return s;
};
bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec, bindings) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
CREATE_BINARY(add, , {0}, 4)
@@ -4454,13 +4525,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4501,19 +4567,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(floor)
CREATE_UNARY(trunc)
CREATE_UNARY(sgn)
CREATE_UNARY(exp)
#undef CREATE_UNARY
#define CREATE_UNARY_RTE(name) \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
}
CREATE_UNARY_RTE(exp)
#undef CREATE_UNARY_RTE
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -4523,13 +4579,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
#define CREATE_GLU(name) \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
}
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
CREATE_GLU(geglu)
CREATE_GLU(reglu)
@@ -4562,25 +4613,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
@@ -4642,13 +4682,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
}
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
if (device->shader_int64 && device->buffer_device_address) {
IM2COL(_bda)
} else {
@@ -6089,6 +6124,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
break;
default:
return nullptr;
@@ -6161,6 +6197,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
break;
default:
return nullptr;
@@ -6227,6 +6264,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
break;
default:
return nullptr;
@@ -6318,6 +6356,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
break;
default:
return nullptr;
@@ -6387,6 +6426,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
break;
default:
return nullptr;
@@ -10024,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
const uint32_t CHW = IC * KH * KW;
// Cap X workgroups to limit concurrent IC channel reads.
// The shader loops over X to cover the full CHW dimension.
// AMD prefers a lower limit
const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
elements = { x_elements, OW, OH * batch };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
@@ -11687,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@@ -11699,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
pelements,
OH * batch,
IC * KH * KW,
s0, s1, p0, p1, d0, d1, batch * IC
});
@@ -14317,8 +14362,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
}
// conditions for pipeline creation
if (!(ctx->device->float_controls_rte_fp16 &&
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) {
return false;
}
@@ -15373,6 +15417,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
break;
default:
return false;
@@ -15488,6 +15533,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_I32:
return true;
default:

View File

@@ -4,7 +4,7 @@
#include "generic_unary_head.glsl"
#include "dequant_funcs.glsl"
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else

View File

@@ -1,6 +1,5 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#if defined(SET_ROWS) && QUANT_K == 1

View File

@@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_NVFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint sub = iqs >> 4;
const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]);
const uint j = iqs & 7;
const uint shift = (iqs & 8) >> 1; // 0 or 4
const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]);
const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]);
const uint qs0 = (vui0 >> shift) & 0xF;
const uint qs1 = (vui1 >> shift) & 0xF;
return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const vec2 v0 = dequantize(ib, iqs, a_offset);
const vec2 v1 = dequantize(ib, iqs + 2u, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
@@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_NVFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1.0, 0.0);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);

View File

@@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
}
#endif
#if defined(DATA_A_NVFP4)
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 {
block_nvfp4 block;
};
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint sub = (idx & 0x30) >> 4;
const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7);
const uint shift = (idx & 0x8) >> 1;
const float d = ue4m3_to_fp32(bl.block.d[sub]);
uint qs = uint(bl.block.qs[iqs]);
qs = (qs >> shift) & 0xF;
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
}
#endif
#if defined(DATA_A_Q1_0)
#define dequantFuncA dequantFuncQ1_0
#elif defined(DATA_A_Q4_0)
@@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_NVFP4)
#define dequantFuncA dequantFuncNVFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif

View File

@@ -0,0 +1,32 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint sub = tid / 16;
const uint ir = tid % 16;
const uint ib = 16 * i + ir;
if (ib >= p.nel / 64) {
return;
}
const uint q_idx = 8 * sub;
const uint b_idx = 1024 * i + 64 * ir + 16 * sub;
const float d = ue4m3_to_fp32(data_a[ib].d[sub]);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}

View File

@@ -1,6 +1,5 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"

View File

@@ -1,6 +1,5 @@
#version 450
#include "rte.glsl"
#include "generic_head.glsl"
#include "types.glsl"

View File

@@ -1,7 +1,6 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.glsl"
#include "utils.glsl"
#if RMS_NORM_ROPE_FUSION
#include "rope_params.glsl"

View File

@@ -1,6 +1,5 @@
#extension GL_EXT_shader_16bit_storage : require
#include "rte.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

View File

@@ -3,7 +3,6 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.glsl"
#include "types.glsl"
layout (push_constant) uniform parameter
@@ -14,7 +13,7 @@ layout (push_constant) uniform parameter
uint IW; uint IH;
uint OW; uint OH;
uint KW; uint KH;
uint pelements;
uint OH_batch;
uint CHW;
int s0; int s1;
int p0; int p1;
@@ -35,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
void im2col(const uint ow, const uint z_idx) {
const uint oh = z_idx % p.OH;
const uint batch_idx = z_idx / p.OH;
const uint oh = y;
const uint batch = z / p.IC;
const uint ic = z % p.IC;
const uint gidx = gl_LocalInvocationID.x;
const uint src_batch = batch_idx * p.batch_offset;
const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
const uint KHKW = p.KH * p.KW;
const uint base_linear_idx = gidx * NUM_ITER;
uint wg_x = gl_WorkGroupID.x;
do {
const uint wg_offset = wg_x * 512;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
[[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
A_TYPE values[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
if (chw_idx >= p.CHW) {
return;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint ic = chw_idx / KHKW;
const uint rem = chw_idx - ic * KHKW;
const uint ky = rem / p.KW;
const uint kx = rem - ky * p.KW;
const uint linear_idx = base_linear_idx + idx;
const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
if (linear_idx >= p.pelements) {
continue;
}
A_TYPE val = A_TYPE(0);
if (iih < p.IH && iiw < p.IW) {
val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx));
out_ptr.d = D_TYPE(val);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
data_d[dst_row + chw_idx] = D_TYPE(val);
#endif
}
}
wg_x += gl_NumWorkGroups.x;
} while (wg_x * 512 < p.CHW);
}
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint ow = gl_GlobalInvocationID.y;
while (ow < p.OW) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
while (z < p.OH_batch) {
im2col(ow, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
ow += gl_NumWorkGroups.y;
}
}

View File

@@ -4,7 +4,6 @@
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "rte.glsl"
#include "types.glsl"
layout (push_constant) uniform parameter

View File

@@ -1,6 +1,5 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"

View File

@@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
kvalues_mxfp4[vui2 >> 4] * d);
#elif defined(DATA_A_NVFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
// lo and hi nibbles are 8 elements apart, which doesn't quite line up with
// how the thread mapping and buf_idx calculation works for other types.
const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2;
const uint ib = idx / 16u;
const uint sub = (idx & 0xC) >> 2;
const uint iqs = (idx & 0xF) * 2;
const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
kvalues_mxfp4[vui2 >> 4] * d);
#endif
}

View File

@@ -8,7 +8,6 @@
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#include "rte.glsl"
#include "types.glsl"
#include "utils.glsl"

View File

@@ -2,7 +2,6 @@
#extension GL_EXT_shader_16bit_storage : require
#include "rte.glsl"
#include "rope_params.glsl"
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;

View File

@@ -1,8 +1,6 @@
#if !defined(GGML_ROPE_PARAMS)
#define GGML_ROPE_PARAMS
#include "rte.glsl"
struct rope_params {
uint rope_mode;
uint nrows;

View File

@@ -1,5 +0,0 @@
#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16

View File

@@ -1,6 +1,5 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"

View File

@@ -1713,6 +1713,22 @@ struct block_mxfp4
#define A_TYPE block_mxfp4
#endif
#define QUANT_K_NVFP4 64
#define QUANT_R_NVFP4 1
struct block_nvfp4
{
uint8_t d[QUANT_K_NVFP4 / 16];
uint8_t qs[QUANT_K_NVFP4 / 2];
};
#if defined(DATA_A_NVFP4)
#define QUANT_K QUANT_K_NVFP4
#define QUANT_R QUANT_R_NVFP4
#define QUANT_AUXF 1
#define A_TYPE block_nvfp4
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
@@ -1732,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize)
}
#endif
#if defined(DATA_A_MXFP4)
#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4)
const int8_t kvalues_mxfp4_const[16] = {
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
@@ -1740,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = {
shared int8_t kvalues_mxfp4[16];
#if defined(DATA_A_NVFP4)
// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero.
shared float ue4m3_fp32_lut[128];
float ue4m3_to_fp32_build(uint u) {
if (u == 0u || u == 127u) {
return 0.0;
}
const uint exp = (u >> 3) & 15u;
const uint man = u & 7u;
if (exp == 0u) {
return float(man) * (1.0 / 512.0);
}
const uint bits = (exp + 120u) << 23 | (man << 20);
return uintBitsToFloat(bits);
}
#endif
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
{
@@ -1747,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize)
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
}
#if defined(DATA_A_NVFP4)
for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) {
ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i);
}
#endif
barrier();
}
#endif
@@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) {
return uintBitsToFloat(bits);
}
#if defined(DATA_A_NVFP4)
float ue4m3_to_fp32(uint8_t x) {
return ue4m3_fp32_lut[uint(x)];
}
#endif
#if BDA
#extension GL_EXT_buffer_reference : enable

View File

@@ -66,6 +66,7 @@ const std::vector<std::string> type_names = {
"iq4_xs",
"iq4_nl",
"mxfp4",
"nvfp4",
"bf16",
};
@@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
std::string load_vec_quant = "2";
if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4"))
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4"))
load_vec_quant = "4";
if (tname == "bf16") {
@@ -744,7 +745,7 @@ void process_shaders() {
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
@@ -768,15 +769,12 @@ void process_shaders() {
for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
auto get_type_str = [](bool f16) {
@@ -793,12 +791,10 @@ void process_shaders() {
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
for (auto rte : {false, true}) {
auto source = op == "add_rms" ? std::string("add") : op;
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
auto add_rms = op == "add_rms" ? "1" : "0";
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
}
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}});
}
}
}
@@ -846,14 +842,11 @@ void process_shaders() {
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -907,21 +900,18 @@ void process_shaders() {
string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -941,25 +931,18 @@ void process_shaders() {
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
@@ -982,7 +965,6 @@ void process_shaders() {
std::string bda_def = bda ? "1" : "0";
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
}
}
@@ -1035,8 +1017,8 @@ void process_shaders() {
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
@@ -1089,8 +1071,8 @@ void write_output_files() {
std::string suffixes[2] = {"_f32", "_f16"};
for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
hdr << "extern const void * " << op << "_data[2][2][2];\n";
hdr << "extern const uint64_t " << op << "_len[2][2][2];\n";
std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
if (basename(input_filepath) != op_file) {
@@ -1098,8 +1080,8 @@ void write_output_files() {
}
std::stringstream data = make_generic_stringstream();
std::stringstream len = make_generic_stringstream();
data << "const void * " << op << "_data[2][2][2][2] = ";
len << "const uint64_t " << op << "_len[2][2][2][2] = ";
data << "const void * " << op << "_data[2][2][2] = ";
len << "const uint64_t " << op << "_len[2][2][2] = ";
for (uint32_t t0 = 0; t0 < 2; ++t0) {
if (t0 == 0) {
data << "{";
@@ -1115,20 +1097,10 @@ void write_output_files() {
data << "{";
len << "{";
}
for (uint32_t rte = 0; rte < 2; ++rte) {
if (rte == 0) {
data << "{";
len << "{";
}
data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
data << "_data,";
len << "_len,";
if (rte == 1) {
data << "}, ";
len << "}, ";
}
}
data << op << suffixes[t0] << suffixes[t1] << suffixes[t2];
len << op << suffixes[t0] << suffixes[t1] << suffixes[t2];
data << "_data,";
len << "_len,";
if (t2 == 1) {
data << "}, ";
len << "}, ";

File diff suppressed because it is too large Load Diff

View File

@@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_u16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
return (word >> shift) & 0xFFFF;
#ifdef DECLARE_BYTE_LOADERS_SRC
fn load_u16_at_src(byte_offset: u32) -> u32 {
let word = src[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_u32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4;
let shift = (byte_offset & 0x3) * 8;
let lo = buf[word_idx];
let hi = buf[word_idx + 1];
let shifted = (lo >> shift) | (hi << (32 - shift));
return select(shifted, lo, shift == 0);
fn load_u32_at_src(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
let lo = src[word_idx];
let hi = src[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
return select(shifted, lo, shift == 0u);
}
fn load_f16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
fn load_f16_at_src(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at_src(byte_offset));
return f16(packed[0]);
}
fn load_f16_as_f32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
let d_bits = (word >> shift) & 0xFFFF;
fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
let word = src[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
return unpack2x16float(d_bits)[0];
}
#endif
#ifdef DECLARE_BYTE_LOADERS_SRC0
fn load_u16_at_src0(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_u32_at_src0(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
let lo = src0[word_idx];
let hi = src0[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
return select(shifted, lo, shift == 0u);
}
fn load_f16_at_src0(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at_src0(byte_offset));
return f16(packed[0]);
}
fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
return unpack2x16float(d_bits)[0];
}
#endif
#endif
#ifdef Q4_1_T

View File

@@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC
#include "common_decls.tmpl"
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
@@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let qh_packed = load_u32_at(&src, block_byte_base + 2);
let d = load_f16_as_f32_at_src(block_byte_base);
let qh_packed = load_u32_at_src(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
@@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 8u; j++) {
let q_byte_offset = block_byte_base + 2u + j * 4u;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
let d = load_f16_as_f32_at_src(block_byte_base + 108);
// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
scale_vals[0] = load_u32_at_src(block_byte_base + 96);
scale_vals[1] = load_u32_at_src(block_byte_base + 100);
scale_vals[2] = load_u32_at_src(block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
}
var dst_i = dst_base + offset * 256;
@@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
let d = load_f16_as_f32_at_src(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16u; i++) {
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
}
var dst_i = dst_base + offset * 256;
@@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src, aux0_offset);
let aux1 = load_u32_at(&src, aux1_offset);
let aux0 = load_u32_at_src(aux0_offset);
let aux1 = load_u32_at_src(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);
for (var ib: u32 = 0; ib < 32; ib += 4) {
@@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
qh_vals[0] = load_u32_at_src(block_byte_base + 66);
qh_vals[1] = load_u32_at_src(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
scale_vals[0] = load_u32_at_src(block_byte_base + 74);
scale_vals[1] = load_u32_at_src(block_byte_base + 78);
for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
@@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src, sc_sign_offset);
let sc_sign = load_u32_at_src(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at(&src, block_byte_base + 106);
var scale_vals = load_u32_at_src(block_byte_base + 106);
for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
@@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 32;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);

View File

@@ -1,7 +1,9 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef FLOAT
const BLOCK_SIZE = 1u;
@@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
@@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
let qh_packed = load_u32_at_src0(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
@@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 8; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
let d = load_f16_as_f32_at_src0(block_byte_base + 108);
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
// and 2-bits from the last 4 bytes
@@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
}
var sum = 0.0;
@@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
let d = load_f16_as_f32_at_src0(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
}
var sum = 0.0;
@@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src0, aux0_offset);
let aux1 = load_u32_at(&src0, aux1_offset);
let aux0 = load_u32_at_src0(aux0_offset);
let aux1 = load_u32_at_src0(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sum = 0.0;
@@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib ++) {
@@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src0, sc_sign_offset);
let sc_sign = load_u32_at_src0(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
var scale_vals = load_u32_at_src0(block_byte_base + 106);
var sum = 0.0;
for (var ib: u32 = 0; ib < 4; ib++) {
@@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 32;
var sum = 0.0;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);

View File

@@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let d = load_f16_at_src0(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let d = load_f16_at_src0(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base + 80u);
let dmin = load_f16_at(&src0, block_byte_base + 82u);
let d = load_f16_at_src0(block_byte_base + 80u);
let dmin = load_f16_at_src0(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
@@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base + 108u);
let d = load_f16_at_src0(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@@ -499,14 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
}
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
@@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@@ -541,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
@@ -575,14 +571,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
}
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
@@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@@ -621,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@@ -673,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
@@ -694,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_f16_at(&src0, block_byte_base + 208u);
let d = load_f16_at_src0(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View File

@@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#ifdef VEC

View File

@@ -1,17 +1,19 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#ifdef VEC
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
return vec4<f32>(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]);
}
#endif
#ifdef SCALAR
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return f32(acc[tm][tn]);
fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
return acc[tm][tn];
}
#endif
@@ -98,7 +100,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
var acc: array<array<f16, TILE_N>, TILE_M>;
var acc: array<array<f32, TILE_N>, TILE_M>;
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
@@ -122,7 +124,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let src1_idx = src1_n * TILE_K + k_inner;
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
for (var tm = 0u; tm < TILE_M; tm++) {
acc[tm][tn] += src0_tile[tm] * src1_val;
acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val);
}
}
}

View File

@@ -3,9 +3,14 @@ enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
// See https://github.com/ggml-org/llama.cpp/issues/21602
#ifdef VEC
fn store_dst(shmem_idx: u32, dst_idx: u32) {
dst[dst_idx] = vec4<f32>(
@@ -193,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
}
}

View File

@@ -1,7 +1,9 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef VEC
#define VEC_SIZE 4
@@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let d = f32(load_f16_at_src0(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
let d = f32(load_f16_at_src0(block_byte_base));
let m = f32(load_f16_at_src0(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
let d = f32(load_f16_at_src0(block_byte_base));
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
let d = f32(load_f16_at_src0(block_byte_base));
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let d = f32(load_f16_at_src0(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = f32(load_f16_at_src0(block_byte_base));
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at(&src0, bbase + 208u));
let d = f32(load_f16_at_src0(bbase + 208u));
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
let ql1_u32 = load_u32_at_src0(bbase + q_offset_l);
let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte);
let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);

View File

@@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-9.010913, 9.010913)));
#endif
#ifdef XIELU
let val = f32(src[params.offset_src + src_idx]);
let res =
select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
src[params.offset_src + src_idx]) *
TYPE(params.alpha_n) +
TYPE(params.beta) * src[params.offset_src + src_idx],
TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] +
TYPE(params.beta) * src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
TYPE(select(
((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val,
params.alpha_p * val * val + params.beta * val,
val > 0.0));
#endif
#ifdef SOFTPLUS
let src_f32 = f32(src[params.offset_src + src_idx]);

Some files were not shown because too many files have changed in this diff Show More