Compare commits

...

30 Commits
b8141 ... b8171

Author SHA1 Message Date
Neo Zhang
c17dce4f5c replace the magic nunber 768 by max work group size to support iGPU (#19920)
Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2026-02-27 09:26:07 +08:00
Vishal Singh
88cf781f51 ggml-zendnn: update code for latest ZenDNN API (#19923)
- adapt ggml-zendnn.cpp to the new lowoha::matmul interface
- update the ZenDNN git tag in CMake to the latest release (ZenDNN‑2026‑WW08)
- add static lib support in CMake
2026-02-27 08:43:41 +08:00
Adrien Gallouët
4e76d24f28 ggml : fix AMX and add batched support (#19925)
llama-perplexity -hf ggml-org/Qwen3-0.6B-GGUF:Q4_0 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048 --chunks 2

before this commit:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 2.31 seconds per pass - ETA 0.07 minutes
[1]17.3868,[2]22.2199,
Final estimate: PPL = 22.2199 +/- 1.59692

llama_perf_context_print:        load time =     878.56 ms
llama_perf_context_print: prompt eval time =    2037.82 ms /  4096 tokens (    0.50 ms per token,  2009.99 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    6403.17 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - CPU_REPACK         |                  288 =   288 +       0 +       0                |
llama_memory_breakdown_print: |   - AMX                |                   31 =    31 +       0 +       0                |
```

after this commit:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 1.98 seconds per pass - ETA 0.05 minutes
[1]17.2005,[2]21.8220,
Final estimate: PPL = 21.8220 +/- 1.56485

llama_perf_context_print:        load time =     719.23 ms
llama_perf_context_print: prompt eval time =    1676.23 ms /  4096 tokens (    0.41 ms per token,  2443.58 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    4258.74 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - AMX                |                  319 =   319 +       0 +       0                |
```
(no more CPU_REPACK)

after this commit, disabling amx:

```
perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1
perplexity: 2.34 seconds per pass - ETA 0.07 minutes
[1]17.2005,[2]21.8220,
Final estimate: PPL = 21.8220 +/- 1.56485

llama_perf_context_print:        load time =     841.91 ms
llama_perf_context_print: prompt eval time =    2057.28 ms /  4096 tokens (    0.50 ms per token,  1990.98 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    6454.51 ms /  4097 tokens
llama_perf_context_print:    graphs reused =          0
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                  845 =   318 +     224 +     302                |
llama_memory_breakdown_print: |   - CPU_REPACK         |                  319 =   319 +       0 +       0                |
```
=> same perplexity.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-02-26 21:39:11 +01:00
Ruben Ortlam
723c71064d vulkan: fix fp16 Flash Attention on Windows AMD RDNA2 and below (#19921) 2026-02-26 19:11:04 +01:00
Georgi Gerganov
37964f44f9 mtmd : fix padding of n_tokens (#19930) 2026-02-26 18:39:49 +02:00
Georgi Gerganov
01cd448b8c server : fix ctx checkpoint restore logic (#19924) 2026-02-26 18:20:16 +02:00
Georgi Gerganov
99bd67c9b2 kv-cache : fix can_shift() check to take into account M-RoPE (#19928) 2026-02-26 18:08:54 +02:00
Aman Gupta
b68d75165a llama: Add option to merge gate and exp weights (#19139)
* llama: Add option to merge gate and exp weights

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update constants.py

* add gate_up for the all MoE models

* convert: simplify merge tensor condition

* update constants.py

* reduce number of models, add create_tensor_gate_up helper

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-26 21:01:08 +08:00
Kevin Pouget
ffaafde16f ggml-virtgpu: improve the reliability of the code (#19846)
* ggml-virtgpu-backend: validate the consistency of the received objects

This patch adds consistency checks in the
ggml-virtgpu-backend (running on the host side) to ensure that the
data received from the guest is consistent (valid pointers, valid
sizes and offsets).

* ggml-virtgpu-backend: add fallback/skips for optional ggml backend methods

```
  1. bck->iface.synchronize(bck)
  2. buft->iface.get_alloc_size(buft, op)
  3. buft->iface.get_max_size(buft)
```

these three methods are optional in the GGML interface. `get_max_size`
was already properly defaulted, but `backend sychronize` and `butf
get_max_size` would have segfaulted the backend if not implemented.

* ggml-virtgpu-backend: fix log format missing argument

* ggml-virtgpu-backend: improve the abort message

* ggml-virtgpu-backend: more safety checks

* ggml-virtgpu-backend: new error code

* ggml-virtgpu-backend: initialize all the error codes

* ggml-virtgpu: add a missing comment generated by the code generator

* ggml-virtgpu: add the '[virtgpu]' prefix to the device/buffer names

* ggml-virtgpu: apir_device_buffer_from_ptr: improve the error message

* ggml-virtgpu: shared: make it match the latest api_remoting.h of Virglrenderer APIR

(still unmerged)

* ggml-virtgpu: update the code generator to have dispatch_command_name in a host/guest shared file

* ggml-virtgpu: REMOTE_CALL: fail if the backend returns an error

* docs/backend/VirtGPU.md: indicate that the RAM+VRAM size is limed to 64 GB with libkrun

* ggml-virtgpu: turn off clang-format header ordering for some of the files

Compilation breaks when ordered alphabetically.

* ggml-virtgpu: clang-format

* ggml-virtgpu/backend/shared/api_remoting: better comments for the APIR return codes
2026-02-26 20:00:57 +08:00
drrros
efba35a860 server: fix load-on-startup not respected in ini file (#19897)
Co-authored-by: Roman Marchenko <r.marchenko@ideco.ru>
2026-02-26 12:32:31 +01:00
Eric Zhang
9b62913b40 jinja : correct default size for string slices (#19913) 2026-02-26 12:28:09 +01:00
Maximilian Werk
66287bdaac model : add Jina Embeddings v5 Nano (partial EuroBERT) support (#19826)
* WIP: Add EuroBERT support with autoformatting changes

This commit includes:
- EuroBERT model implementation for GGUF conversion
- C++ backend support for EuroBERT architecture
- Unintended autoformatting changes to Python files

Saving before reverting formatting-only changes.

* feat: add back eos assert when not last token pooling

* feat: removed duplicated code and cleanup

* feat: removed not working architectures and unnecessary check

* fix: typo

* fix: dynamic pooling config

* feat: added an example model for eurobert

* feat: proper llama-vocab implementation for jina-v5

* fix: removed unnecessary comments
2026-02-26 12:14:09 +01:00
Georgi Gerganov
1ca3d1de15 gguf : avoid too many file size calls (#19919) 2026-02-26 12:46:32 +02:00
yggdrasil75
bd72300591 server : fix typo in server README.md (#19900)
fix typo
2026-02-26 11:26:16 +01:00
Neo Zhang
2943210c1e support permuted, remove check s0/s10 (#19889)
Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2026-02-26 10:27:20 +08:00
Jeff Bolz
3769fe6eb7 vulkan: check for memory overlap before doing fusion (#19768)
* vulkan: check for memory overlap before doing fusion

* Update ggml/src/ggml-vulkan/ggml-vulkan.cpp

* address feedback
2026-02-25 18:25:38 +01:00
ddh0
832aa94762 common : add more aliases for sampler CLI params (#19797)
* common : add more aliases for sampler CLI params
2026-02-25 16:34:25 +01:00
Slobodan Josic
3af34b9ff5 ci : update the ROCm/HIP toolchain versions [no ci] (#19891)
* [HIP] Update ROCm build container to rocm/dev-ubuntu-22.04:7.2 and HIP_SDK to 26.Q1

* revert container version

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-25 15:54:49 +01:00
Georgi Gerganov
f20469d919 server : enable multi-modal prompt caching (#19877) 2026-02-25 15:15:42 +02:00
Georgi Gerganov
d7d826b3c1 server : support multi-modal context checkpoints (#19849)
* Modify llama-memory-hybrid-iswa.cpp

* Modify llama-memory-recurrent.cpp

* Modify server-common.cpp

* Modify server-common.h

* Modify server-context.cpp

* Modify server-task.h

* Added comment to llama-memory-hybrid-iswa.cpp

* Remove comment from server-context.cpp

* Stylistic fix server-context.cpp

* Fix an issue when seqrm isn't called in server-context.cpp

* cont : alternative impl

* cont : cleanup

* cont : n_tokens -> int64_t

---------

Co-authored-by: timkhronos <timkhronos@gmail.com>
2026-02-25 15:14:27 +02:00
Xuan-Son Nguyen
c747294b2d scripts: update corpus of compare-logprobs (#19326)
* scripts: update corpus of compare-logprobs

* fix
2026-02-25 12:57:34 +01:00
Mario Limonciello
8fdf269dad ci : update Windows ROCm build to 26.Q1 [no ci] (#19810)
* Update build command to build llama-* tools not just ggml-hip
* Update rocWMMA headers to 7.2
* Add GFX1150 target
* Correct library paths for AMD libraries in 26.Q1
2026-02-25 12:30:19 +01:00
Aldehir Rojas
a96a1120b4 gguf : fix ftell/fseek for Windows (#19870) 2026-02-25 06:58:11 +02:00
Georgi Gerganov
244641955f models : fix graph splits (#19866) 2026-02-25 00:01:13 +02:00
Pascal
47eb12b953 server: fix query params lost when proxying requests in multi-model router mode (#19854)
* server: fix query params lost when proxying requests in multi-model router mode

* server: re-encode query params using httplib::encode_query_component in proxy
2026-02-24 21:46:06 +01:00
Georgi Gerganov
418dea39ce ggml/gguf : prevent integer overflows (#19856)
* gguf : prevent integer overflow for ggml_context mem size

* ggml : fix int overflows in ggml_new_object()

* gguf : prevent string exhaustion

* gguf : prevent array elements exhaustion

* ggml : fix negative tensor type oob

* py : assert that alignment is non-zero power of 2

* ggml : check int overflow in ggml_new_tensor_impl and ggml_new_object

* gguf-py : error on duplicate keys when reading

* py : restore tensor_fields

* enforce proper alignment in add_custom_alignment

* gguf : better name

* gguf : fix ctx size for no_alloc == true

* gguf : minor print fix

* ggml : print values when overflow

* ggml : remove deprecated ggml_type_sizef()

* ggml : relax ggml_type asserts to debug-only

* gguf : add mem_size overflow test

* gguf : add file size check for arrays

* ggml : relax asseerts for ggml_get_type_traits()

* flake8 fix

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-24 20:17:11 +02:00
Tarek Dakhran
da426cb250 model : update label for LFM2-24B-A2B (#19848)
* model : Update label for LFM2-24B-A2B

```
❯ build/bin/llama-bench -m /data/playground/checkpoints/LFM2-24B-A2B-Preview-Q4_0.gguf,/data/playground/checkpoints/LFM2-8B-A1B-Q4_0.gguf -p 1 -n 0
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| lfm2moe 24B.A2B Q4_0           |  12.54 GiB |    23.84 B | CPU        |      10 |             pp1 |         30.35 ± 2.49 |
| lfm2moe 8B.A1B Q4_0            |   4.41 GiB |     8.34 B | CPU        |      10 |             pp1 |         49.24 ± 1.93 |
```

* Remove extra line
2026-02-24 14:27:42 +01:00
Radoslav Gerganov
c830f99cfa server : support max_completion_tokens request property (#19831)
"max_tokens" is deprectated in favor of "max_completion_tokens" which
sets the upper bound for reasoning+output token.

Closes: #13700
2026-02-24 10:30:00 +02:00
Ruben Ortlam
aa6f918c1c Vulkan Scalar Flash Attention Refactor (#19625)
* vulkan: allow using fp16 in scalar flash attention shader

* split rows inside of subgroups for faster synchronization

* use row_split when Br >= 4, change reductions to use shared memory if row_split == 1

* use f32 scalar FA if f16 is not supported by device

* fix amd workgroup size issue

* optimize masksh use

* add medium rows FA shader Br size

* fixes

* add padding to mask shmem buffer

* cache q values into registers for KQ

* fuse lf accumulation, pf and v accumulation into a loop

* stage K loads through shmem

* stage V loads through shmem

* only stage through shmem on Nvidia

* default to Bc 32

* also stage V through shmem when this is done for K

* dynamic subgroups for intel

* use vectorized stores

* use float_type for dequantize4 functions

* use smaller scalar rows size for smaller rows count

* relax flash attention split_k condition to allow non-gqa use

* use minimal subgroup size on Intel

* fix shmem support function

* fix rebase issues

* fixes

* Bc 4 for scalar FA is not a valid configuration

* Use wave32 on AMD RDNA for scalar FA

* add Intel shader core count lookup-table

* fix regressions

* device tuning

* tmpsh size fix

* fix editorconfig

* refactor fa tuning logic into a single place

* fix gqa opt logic

* fix block_rows with small n_rows

* amd tuning

* fix hsk=72/80 issue

* tuning

* allow condition skipping for column check

* use float16 for Of if available

* address feedback

* fix bad RDNA performance on head size <= 128 by limiting occupancy

* allow printing pipeline stats

* cleanup and fixes

* limit occupancy for GCN for small batch FA with large HSK

* disable f16 FA for GCN AMD GPUs on the proprietary driver
2026-02-24 08:35:48 +01:00
Jeff Bolz
8c2c0108dd vulkan: fix coopmat1 without bf16 support (#19793) 2026-02-24 07:48:32 +01:00
89 changed files with 2409 additions and 1205 deletions

View File

@@ -11,5 +11,5 @@ runs:
- name: Setup ROCm
uses: ./.github/actions/install-exe
with:
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-Win11-For-HIP.exe
args: -install

View File

@@ -68,7 +68,7 @@ jobs:
env:
# Make sure this is in sync with build.yml
HIPSDK_INSTALLER_VERSION: "25.Q3"
HIPSDK_INSTALLER_VERSION: "26.Q1"
steps:
- name: Clone

View File

@@ -1175,10 +1175,8 @@ jobs:
runs-on: windows-2022
env:
# The ROCm version must correspond to the version used in the HIP SDK.
ROCM_VERSION: "6.4.2"
# Make sure this is in sync with build-cache.yml
HIPSDK_INSTALLER_VERSION: "25.Q3"
HIPSDK_INSTALLER_VERSION: "26.Q1"
steps:
- name: Clone
@@ -1188,7 +1186,7 @@ jobs:
- name: Grab rocWMMA package
id: grab_rocwmma
run: |
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/${{ env.ROCM_VERSION }}/pool/main/r/rocwmma-dev/rocwmma-dev_1.7.0.60402-120~24.04_amd64.deb"
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
7z x rocwmma.deb
7z x data.tar
@@ -1231,7 +1229,7 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
-DCMAKE_BUILD_TYPE=Release `
-DLLAMA_BUILD_BORINGSSL=ON `
-DROCM_DIR="${env:HIP_PATH}" `

View File

@@ -616,13 +616,13 @@ jobs:
runs-on: windows-2022
env:
HIPSDK_INSTALLER_VERSION: "25.Q3"
HIPSDK_INSTALLER_VERSION: "26.Q1"
strategy:
matrix:
include:
- name: "radeon"
gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
gpu_targets: "gfx1150;gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
steps:
- name: Clone
@@ -632,7 +632,7 @@ jobs:
- name: Grab rocWMMA package
id: grab_rocwmma
run: |
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.0.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.0.0.70001-42~24.04_amd64.deb"
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
7z x rocwmma.deb
7z x data.tar
@@ -655,7 +655,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "Downloading AMD HIP SDK Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-Win11-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$completed = $proc.WaitForExit(600000)
@@ -689,20 +689,20 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.0.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_BUILD_TYPE=Release `
-DGGML_BACKEND_DL=ON `
-DGGML_NATIVE=OFF `
-DGGML_CPU=OFF `
-DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" `
-DGPU_TARGETS="${{ matrix.gpu_targets }}" `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGGML_HIP=ON `
-DLLAMA_BUILD_BORINGSSL=ON
cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
md "build\bin\rocblas\library\"
md "build\bin\hipblaslt\library"
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\hipblaslt.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\libhipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\libhipblaslt.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
cp "${env:HIP_PATH}\bin\hipblaslt\library\*" "build\bin\hipblaslt\library\"

View File

@@ -1578,7 +1578,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--temp"}, "N",
{"--temp", "--temperature"}, "N",
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
@@ -1611,7 +1611,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--top-nsigma"}, "N",
{"--top-nsigma", "--top-n-sigma"}, "N",
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
[](common_params & params, const std::string & value) {
params.sampling.top_n_sigma = std::stof(value);
@@ -1634,7 +1634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
{"--typical", "--typical-p"}, "N",
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
[](common_params & params, const std::string & value) {
params.sampling.typ_p = std::stof(value);

View File

@@ -721,6 +721,8 @@ value member_expression::execute_impl(context & ctx) {
int64_t arr_size = 0;
if (is_val<value_array>(object)) {
arr_size = object->as_array().size();
} else if (is_val<value_string>(object)) {
arr_size = object->as_string().length();
}
if (is_stmt<slice_expression>(this->property)) {

View File

@@ -116,7 +116,8 @@ class ModelBase:
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False):
sentence_transformers_dense_modules: bool = False,
fuse_gate_up_exps: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -135,6 +136,9 @@ class ModelBase:
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
self.fuse_gate_up_exps = fuse_gate_up_exps
self._gate_exp_buffer: dict[int, Tensor] = {}
self._up_exp_buffer: dict[int, Tensor] = {}
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
@@ -512,8 +516,31 @@ class ModelBase:
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
return [(self.map_tensor_name(name), data_torch)]
new_name = self.map_tensor_name(name)
# Handle gate/up expert tensor fusion if enabled
if self.fuse_gate_up_exps and bid is not None:
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
self._gate_exp_buffer[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
self._up_exp_buffer[bid] = data_torch
# Check if both gate and up are buffered for this layer
if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
gate_data = self._gate_exp_buffer.pop(bid)
up_data = self._up_exp_buffer.pop(bid)
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
fused_data = torch.cat([gate_data, up_data], dim=1)
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
return [(fused_name, fused_data)]
# If we buffered a gate/up tensor, wait for the other
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
return []
return [(new_name, data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid, n_dims # unused
@@ -1148,6 +1175,9 @@ class TextModel(ModelBase):
if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de
res = "jina-v2-de"
if chkhsh == "a023e9fdc5a11f034d3ef515b92350e56fb2af1f66c6b6811a4444ea9bf8763d":
# ref: https://huggingface.co/jinaai/jina-embeddings-v5-text-nano
res = "jina-v5-nano"
if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d":
# ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct
res = "smaug-bpe"
@@ -6125,6 +6155,32 @@ class NeoBert(BertModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("EuroBertModel", "JinaEmbeddingsV5Model")
class EuroBertModel(TextModel):
model_arch = gguf.MODEL_ARCH.EUROBERT
def set_vocab(self):
self.gguf_writer.add_add_bos_token(False)
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
# EuroBert is bidirectional (encoder)
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Strip "model." prefix from tensor names
if name.startswith("model."):
name = name[6:]
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
@@ -11913,6 +11969,11 @@ def parse_args() -> argparse.Namespace:
"Default these modules are not included.")
)
parser.add_argument(
"--fuse-gate-up-exps", action="store_true",
help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
@@ -12050,7 +12111,8 @@ def main() -> None:
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
fuse_gate_up_exps=args.fuse_gate_up_exps
)
if args.vocab_only:

View File

@@ -107,6 +107,7 @@ models = [
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
{"name": "jina-v5-nano", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v5-text-nano", },
{"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },

View File

@@ -152,7 +152,9 @@ Commands and data are serialized using a custom binary protocol with:
- **VM-specific**: Only works in virtual machines with virtio-gpu support
- **Host dependency**: Requires properly configured host-side backend
- **Latency**: Small overhead from VM escaping for each operation
- **Shared-memory size**: with the `libkrun` hypervisor, the RAM + VRAM
addressable memory is limited to 64 GB. So the maximum GPU memory
will be `64GB - RAM`, regardless of the hardware VRAM size.
* This work is pending upstream changes in the VirglRenderer
project.

View File

@@ -22,7 +22,7 @@
**Llama.cpp + ZenDNN**
The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL BLIS, LibXSMM, OneDNN).
The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL DLP, LibXSMM, OneDNN).
For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendnn.html
@@ -32,7 +32,7 @@ For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendn
|:-------:|:-------:|:----------------------------------------------:|
| Linux | Support | Ubuntu 20.04, 22.04, 24.04 |
For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/zendnnl/README.md#15-supported-os).
For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md#15-supported-os).
## Hardware
@@ -44,9 +44,9 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
| CPU Family | Status | Notes |
|:-----------------------------:|:-------:|:----------------------------------:|
| AMD EPYC™ 9005 Series (Turin)| Support | 5th Gen - Zen 5 architecture |
| AMD EPYC™ 9004 Series (Genoa)| Support | 4th Gen - Zen 4 architecture |
| AMD EPYC™ 7003 Series (Milan)| Support | 3rd Gen - Zen 3 architecture |
| AMD EPYC™ 9005 Series (Turin) | Support | 5th Gen - Zen 5 architecture |
| AMD EPYC™ 9004 Series (Genoa) | Support | 4th Gen - Zen 4 architecture |
| AMD EPYC™ 7003 Series (Milan) | Support | 3rd Gen - Zen 3 architecture |
| AMD Ryzen™ AI MAX (Strix Halo)| Support | High-performance mobile processors |
*Notes:*
@@ -61,7 +61,7 @@ The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** ope
| Operation | Status | Notes |
|:-------------|:-------:|:----------------------------------------------:|
| MUL_MAT | | Accelerated via ZenDNN LowOHA MatMul |
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
@@ -104,7 +104,6 @@ If you want to build ZenDNN yourself or use a specific version:
# Clone ZenDNN repository
git clone https://github.com/amd/ZenDNN.git
cd ZenDNN
git checkout zendnnl
# Build and install (requires CMake >= 3.25)
mkdir build && cd build
@@ -114,7 +113,7 @@ cmake --build . --target all
Default installation path: `ZenDNN/build/install`
**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/zendnnl/README.md).
**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md).
**Step 2: Build llama.cpp with custom ZenDNN path**
@@ -146,8 +145,7 @@ Run llama.cpp server with ZenDNN acceleration:
```sh
# Set optimal configuration
export OMP_NUM_THREADS=64 # Adjust to your CPU core count
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance
export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo for best performance
# Start server
./build/bin/llama-server \
@@ -160,62 +158,26 @@ export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance
Access the server at `http://localhost:8080`.
**Performance tips**:
- Set `OMP_NUM_THREADS` to match your physical core count
- Use `ZENDNNL_MATMUL_ALGO=2` for optimal performance
- Use `ZENDNNL_MATMUL_ALGO=1` for optimal performance
- For NUMA systems: `numactl --cpunodebind=0 --membind=0 ./build/bin/llama-server ...`
## Environment Variable
### Build Time
For environment variables related to ZenDNN, refer to the [ZenDNN Environment Variables Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md).
| Name | Value | Function |
|--------------------|---------------------------------------|---------------------------------------------|
| GGML_ZENDNN | ON/OFF | Enable ZenDNN backend support |
| ZENDNN_ROOT | Path to ZenDNN installation | Set ZenDNN installation directory |
| GGML_OPENMP | ON/OFF (recommended: ON) | Enable OpenMP for multi-threading |
### Performance Optimization
### Runtime
| Name | Value | Function |
|-------------------------|--------------------------|-------------------------------------------------------------------|
| OMP_NUM_THREADS | Number (e.g., 64) | Set number of OpenMP threads (recommended: physical core count) |
| ZENDNNL_MATMUL_ALGO | 0-5 | Select MatMul backend algorithm (see Performance Optimization) |
| ZENDNNL_PROFILE_LOG_LEVEL | 0-4 | Profiling log level (0=disabled, 4=verbose) |
| ZENDNNL_ENABLE_PROFILER | 0 or 1 | Enable detailed profiling (1=enabled) |
| ZENDNNL_API_LOG_LEVEL | 0-4 | API log level (0=disabled, 4=verbose) |
**Example**:
ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL DLP** algorithm:
```sh
export OMP_NUM_THREADS=64
export ZENDNNL_MATMUL_ALGO=2 # Use Blocked AOCL BLIS for best performance
./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Test" -n 100
export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo (recommended)
```
## Performance Optimization
### MatMul Algorithm Selection
ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL BLIS** algorithm:
```sh
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS (recommended)
```
**Available algorithms**:
| Value | Algorithm | Description |
|:-----:|:-----------------------|:----------------------------------------------|
| 0 | Dynamic Dispatch | Automatic backend selection (default) |
| 1 | AOCL BLIS | AOCL BLIS backend |
| 2 | AOCL BLIS Blocked | **Blocked AOCL BLIS (recommended)** |
| 3 | OneDNN | OneDNN backend |
| 4 | OneDNN Blocked | Blocked OneDNN |
| 5 | LibXSMM | LibXSMM backend |
For more details on available algorithms, see the [ZenDNN MatMul Algorithm Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md#algorithm-details).
### Profiling and Debugging
For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/zendnnl/docs/logging.md).
For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/logging.md).
## Known Issues
@@ -245,10 +207,9 @@ A: Currently, ZenDNN primarily supports FP32 and BF16 data types. Quantized mode
A: Ensure:
1. You're using an AMD EPYC or Ryzen processor (Zen 2 or newer)
2. `OMP_NUM_THREADS` is set appropriately (physical core count)
3. `ZENDNNL_MATMUL_ALGO=2` is set for best performance (Blocked AOCL BLIS)
4. You're using a sufficiently large model (small models may not benefit as much)
5. Enable profiling to verify ZenDNN MatMul is being called
2. `ZENDNNL_MATMUL_ALGO=1` is set for best performance (Blocked AOCL DLP)
3. You're using a sufficiently large model (small models may not benefit as much)
4. Enable profiling to verify ZenDNN MatMul is being called
### **GitHub Contribution**:
Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-team check/address them without delay.

View File

@@ -730,10 +730,6 @@ extern "C" {
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
GGML_DEPRECATED(
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
"use ggml_row_size() instead");
GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);

View File

@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
namespace ggml::cpu::amx {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
// handle only 2d gemm for now
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
};
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
// src1 must be host buffer
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
// src1 must be float32
if (op->src[1]->type == GGML_TYPE_F32) {
return true;
}
if (op->op != GGML_OP_MUL_MAT) {
return false;
}
return false;
auto * src0 = op->src[0];
auto * src1 = op->src[1];
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
return false;
}
if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
return false;
}
if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
return false;
}
if (op->ne[0] % (TILE_N * 2)) {
return false;
}
int alignment;
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q8_0:
alignment = TILE_K;
break;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_XS:
alignment = 256; // QK_K
break;
case GGML_TYPE_F16:
alignment = 16;
break;
default:
return false;
}
if (src0->ne[0] % alignment) {
return false;
}
if (src1->type != GGML_TYPE_F32) {
return false;
}
return true;
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {

View File

@@ -1,4 +1,3 @@
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
@@ -202,35 +201,27 @@ struct tile_config_t{
// advanced-matrix-extensions-intrinsics-functions.html
//
#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
void ggml_tile_config_init(void) {
static thread_local bool is_first_time = true;
inline void ggml_tile_config_init(void) {
static thread_local bool done = false;
if (!is_first_time) {
if (done) {
return;
}
static thread_local tile_config_t tc;
tile_config_t current_tc;
_tile_storeconfig(&current_tc);
alignas(64) tile_config_t tc = {};
tc.palette_id = 1;
tc.start_row = 0;
tc.rows[0] = 8; tc.colsb[0] = 64;
tc.rows[1] = 8; tc.colsb[1] = 64;
tc.rows[2] = 16; tc.colsb[2] = 32;
tc.rows[3] = 16; tc.colsb[3] = 32;
tc.rows[4] = 16; tc.colsb[4] = 64;
tc.rows[5] = 16; tc.colsb[5] = 64;
tc.rows[6] = 16; tc.colsb[6] = 64;
tc.rows[7] = 16; tc.colsb[7] = 64;
// load only when config changes
if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
tc.palette_id = 1;
tc.start_row = 0;
TC_CONFIG_TILE(TMM0, 8, 64);
TC_CONFIG_TILE(TMM1, 8, 64);
TC_CONFIG_TILE(TMM2, 16, 32);
TC_CONFIG_TILE(TMM3, 16, 32);
TC_CONFIG_TILE(TMM4, 16, 64);
TC_CONFIG_TILE(TMM5, 16, 64);
TC_CONFIG_TILE(TMM6, 16, 64);
TC_CONFIG_TILE(TMM7, 16, 64);
_tile_loadconfig(&tc);
}
is_first_time = false;
_tile_loadconfig(&tc);
done = true;
}
// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
@@ -268,33 +259,6 @@ int get_row_size(int K) {
return row_size;
}
// vectorized dtype conversion
inline float FP16_TO_FP32(ggml_half val) {
__m256i v = _mm256_setr_epi16(
val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512 o = _mm512_cvtph_ps(v);
return _mm512_cvtss_f32(o);
}
inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
__m256i v = _mm256_set1_epi16(val);
return _mm512_cvtph_ps(v);
}
// horizontal reduce
inline float _mm512_reduce_max_ps(const __m512 x) {
__m512 v = x;
__m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
v = _mm512_max_ps(v, v1);
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
v = _mm512_max_ps(v, v1);
v1 = _mm512_shuffle_ps(v, v, 0x4E);
v = _mm512_max_ps(v, v1);
v1 = _mm512_shuffle_ps(v, v, 0xB1);
v = _mm512_max_ps(v, v1);
return _mm512_cvtss_f32(v);
}
// transpose utils
#define SHUFFLE_EPI32(a, b, mask) \
_mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
K, (const float *)src1->data + mb_start * K, \
(const type *)src0->data + nb_start * K, \
(float *)dst->data + mb_start * ldc + nb_start, ldc);
K, (const float *)src1->data + src1_offset + mb_start * K, \
(const type *)src0->data + src0_offset + nb_start * K, \
(float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
// re-organize in the format {NB, KB, TILE_SIZE}:
@@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
}
};
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
KB, (const char *)wdata + 0 * row_size_A, \
(const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
(float *) dst->data + 0 * N + nb_start, ldc)
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
KB, wdata_batch, \
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
(float *) dst->data + dst_offset + nb_start, ldc)
template <typename TA, typename TB, typename TC, int BLOCK_K,
typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
if (need_unpack) {
unpack_B<TB>(Tile1, B_blk0);
unpack_B<TB>(Tile1, B_blk1);
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
} else {
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
});
}
// ne2 is passed explicitly to help compiler optimize repeated calls
inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
const int64_t i2 = batch_idx % ne2;
const int64_t i3 = batch_idx / ne2;
return i3 * t->nb[3] + i2 * t->nb[2];
}
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
struct ggml_tensor * src0 = dst->src[0];
@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
const int M = dst->ne[1];
const int K = src0->ne[0];
const int64_t n_batch = dst->ne[2] * dst->ne[3];
size_t desired_wsize = 0;
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
desired_wsize = M * row_size_A;
desired_wsize = n_batch * M * row_size_A;
});
return desired_wsize;
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
// src1: input in shape of {M, K}, float32
// dst: output in shape of {M, N}, float32
//
// the function performs: dst = src1 @ src0.T
// the function performs: dst = src1 @ src0.T for each batch
//
void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
struct ggml_tensor * src0 = dst->src[0];
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
const int K = src0->ne[0];
const int ldc = dst->nb[1] / dst->nb[0];
const int64_t ne2 = dst->ne[2];
const int64_t n_batch = ne2 * dst->ne[3];
if (is_floating_type) {
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 6;
const int MB = div_up(M, BLOCK_M);
const int NB = div_up(N, BLOCK_N);
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
for (int i = begin; i < end; ++i) {
int mb = i / NB;
int nb = i % NB;
int batch_idx = i / (MB * NB);
int remaining = i % (MB * NB);
int mb = remaining / NB;
int nb = remaining % NB;
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
void * wdata = params->wdata;
//TODO: performance improvement: merge quant A
if (params->ith == 0) {
// if (params->ith == 0) {
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
const size_t desired_wsize = M * row_size_A;
const size_t desired_wsize = n_batch * M * row_size_A;
if (params->wsize < desired_wsize) {
GGML_ABORT("insufficient work space size");
}
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
const float * A_data = static_cast<const float *>(src1->data);
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
}
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
}
}
});
});
}
// }
ggml_barrier(params->threadpool);
@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
constexpr int BLOCK_N = TILE_N * kTilesN;
const int NB = div_up(N, BLOCK_N);
parallel_for_ggml(params, NB, [&](int begin, int end) {
parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
GGML_DISPATCH_QTYPES(TYPE, [&] {
const int KB = K / blck_size;
const int TILE_SIZE = get_tile_size<type>();
const int row_size_A = KB * sizeof(vec_dot_type);
for (int i = begin; i < end; ++i) {
int nb = i;
int batch_idx = i / NB;
int nb = i % NB;
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
const int MB = div_up(M, BLOCK_M);
const int NB = div_up(N, BLOCK_N);
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
// init tile config for each thread
ggml_tile_config_init();
@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
const int row_size_A = KB * sizeof(vec_dot_type);
for (int i = begin; i < end; ++i) {
int mb = i / NB;
int nb = i % NB;
int batch_idx = i / (MB * NB);
int remaining = i % (MB * NB);
int mb = remaining / NB;
int nb = remaining % NB;
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
mb_size, nb_size, KB,
(const char *)wdata + mb_start * row_size_A,
(const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
(float *) dst->data + mb_start * N + nb_start, ldc);
wdata_batch + mb_start * row_size_A,
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
(float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
}
});
});

View File

@@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const int32_t* src2_d = (const int32_t*)src2->data;
float* dst_d = (float*)dst->data;
int threads = std::min((int)ne00, 768); // cols
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
int threads = std::min((unsigned int)ne00, max_work_group_size); // cols
ctx.stream()->parallel_for(
sycl::nd_range<3>(
sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),

View File

@@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
int s00, int s01, int s02, int s03,
int s10, int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
for (int i0 = i0s; i0 < ne0;
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
}
}
@@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
int s00, int s01, int s02, int s03,
int s10, int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
}
@@ -95,7 +95,8 @@ struct bin_bcast_sycl {
const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,
queue_ptr stream) {
int nr0 = ne10 / ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
@@ -123,7 +124,7 @@ struct bin_bcast_sycl {
cnb[3] *= cne[3];
};
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
@@ -164,7 +165,7 @@ struct bin_bcast_sycl {
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
// size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
@@ -196,9 +197,6 @@ struct bin_bcast_sycl {
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL);
@@ -232,8 +230,8 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
s03, s11, s12, s13, item_ct1);
ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,
s03, s10, s11, s12, s13, item_ct1);
});
}
} else {
@@ -251,7 +249,7 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13,
item_ct1);
});
}
@@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@@ -7,9 +7,21 @@
#include <cstdint>
static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) {
if (cgraph_size == 0) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation);
return 1;
}
// place-holder: validate that the size of shmem_res_id is <= cgraph_size
// need to add another method in the Virgl->APIR callback interface
GGML_UNUSED(shmem_res_id);
return 0; // Valid
}
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
static bool async_backend_initialized = false;
static bool async_backend;
@@ -34,10 +46,26 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
size_t cgraph_size;
apir_decode_size_t(dec, &cgraph_size);
if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) {
apir_decoder_set_fatal(dec);
return 1;
}
apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__);
return 1;
}
if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__,
cgraph->n_nodes, cgraph->n_leafs);
return 1;
}
ggml_status status;
#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
for (int idx = 0; idx < cgraph->n_nodes; idx++) {
@@ -45,7 +73,8 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
if (dev->iface.supports_op(dev, op)) {
continue;
}
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op));
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx,
ggml_op_desc(op));
status = GGML_STATUS_ABORTED;
apir_encode_ggml_status(enc, &status);
@@ -53,9 +82,17 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
return 0;
}
#endif
// Check if backend is properly initialized
if (!bck) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__);
return 1;
}
status = bck->iface.graph_compute(bck, cgraph);
if (async_backend) {
if (async_backend && bck->iface.synchronize) {
bck->iface.synchronize(bck);
}

View File

@@ -85,7 +85,19 @@ uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * d
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
size_t value = buft->iface.get_alloc_size(buft, op);
// Check for decode error
if (op == nullptr) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__);
apir_decoder_set_fatal(dec);
return 1;
}
size_t value;
if (buft->iface.get_alloc_size) {
value = buft->iface.get_alloc_size(buft, op);
} else {
value = ggml_nbytes(op); // Default fallback
}
apir_encode_size_t(enc, &value);

View File

@@ -6,11 +6,26 @@
#include <cstdint>
static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) {
// Only check for critical integer overflow - no arbitrary size limits
if (offset > SIZE_MAX - size) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size);
return 1;
}
return 0; // Valid
}
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
apir_encode_uintptr_t(enc, &base);
@@ -24,6 +39,11 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
ggml_tensor * tensor;
// safe to remove the const qualifier here
tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
@@ -37,6 +57,10 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl
size_t size;
apir_decode_size_t(dec, &size);
if (validate_buffer_operation(offset, size, __func__) != 0) {
return 1;
}
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
@@ -56,6 +80,11 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
const ggml_tensor * tensor;
// safe to remove the const qualifier here
tensor = apir_decode_ggml_tensor(dec);
@@ -69,6 +98,10 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl
size_t size;
apir_decode_size_t(dec, &size);
if (validate_buffer_operation(offset, size, __func__) != 0) {
return 1;
}
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
@@ -86,6 +119,11 @@ uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
const ggml_tensor * src;
// safe to remove the const qualifier here
src = apir_decode_ggml_tensor(dec);
@@ -105,6 +143,11 @@ uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
uint8_t value;
apir_decode_uint8_t(dec, &value);
@@ -120,6 +163,11 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
if (!apir_untrack_backend_buffer(buffer)) {
GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer);
return 1;

View File

@@ -1,6 +1,6 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
@@ -28,19 +28,24 @@ uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {
return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;
}
if (!reg->iface.get_device_count(reg)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__);
size_t device_count = reg->iface.get_device_count(reg);
if (!device_count) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__);
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
}
dev = reg->iface.get_device(reg, 0);
if (!dev) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__);
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
}
bck = dev->iface.init_backend(dev, NULL);
if (!bck) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__);
return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED;
}
return APIR_BACKEND_INITIALIZE_SUCCESS;
}

View File

@@ -32,64 +32,6 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg
/* backend */
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) {
switch (type) {
/* device */
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
return "backend_device_get_device_count";
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
return "backend_device_get_count";
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
return "backend_device_get_name";
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
return "backend_device_get_description";
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
return "backend_device_get_type";
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
return "backend_device_get_memory";
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
return "backend_device_supports_op";
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
return "backend_device_get_buffer_type";
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
return "backend_device_get_props";
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
return "backend_device_buffer_from_ptr";
/* buffer-type */
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
return "backend_buffer_type_get_name";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
return "backend_buffer_type_get_alignment";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
return "backend_buffer_type_get_max_size";
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
return "backend_buffer_type_is_host (DEPRECATED)";
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
return "backend_buffer_type_alloc_buffer";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
return "backend_buffer_type_get_alloc_size";
/* buffer */
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
return "backend_buffer_get_base";
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
return "backend_buffer_set_tensor";
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
return "backend_buffer_get_tensor";
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
return "backend_buffer_cpy_tensor";
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
return "backend_buffer_clear";
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
return "backend_buffer_free_buffer";
/* backend */
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
return "backend_backend_graph_compute";
default:
return "unknown";
}
}
extern "C" {
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {

View File

@@ -1,5 +1,6 @@
#pragma once
// clang-format off
#include <cstdint>
#include <cstddef>
@@ -10,6 +11,7 @@
#include "shared/apir_backend.h"
#include "shared/apir_cs.h"
#include "shared/apir_cs_ggml.h"
// clang-format on
#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: "

View File

@@ -19,7 +19,7 @@ struct virgl_apir_callbacks {
};
extern "C" {
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs);
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs);
void apir_backend_deinit(uint32_t virgl_ctx_id);
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
virgl_apir_callbacks * virgl_cbs,

View File

@@ -1,6 +1,5 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "shared/api_remoting.h"
#include "shared/apir_backend.h"
#include "shared/apir_cs.h"
@@ -17,10 +16,10 @@
#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
static void * backend_library_handle = NULL;
static FILE * apir_logfile = NULL;
static FILE * apir_logfile = NULL;
static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
FILE * logfile = (FILE *)user_data;
FILE * logfile = (FILE *) user_data;
fprintf(logfile, "[%d] %s", level, text);
fflush(logfile);
}
@@ -48,9 +47,9 @@ void apir_backend_deinit(uint32_t virgl_ctx_id) {
}
#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) {
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) {
const char * dlsym_error;
const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
@@ -63,15 +62,13 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
}
}
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
if (!library_name) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot open the GGML library: env var '%s' not defined\n",
__func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__,
APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
}
@@ -79,16 +76,14 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
backend_library_handle = dlopen(library_name, RTLD_LAZY);
if (!backend_library_handle) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot open the GGML library: %s\n", __func__, dlerror());
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror());
return APIR_LOAD_LIBRARY_CANNOT_OPEN;
}
if (!library_reg) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot register the GGML library: env var '%s' not defined\n",
__func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__,
APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
}
@@ -96,11 +91,9 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
dlsym_error = dlerror();
if (dlsym_error) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
__func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
}
@@ -132,13 +125,12 @@ uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
virgl_apir_context ctx = {
.ctx_id = virgl_ctx_id,
.iface = virgl_cbs,
.iface = virgl_cbs,
};
if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: Received an invalid dispatch index (%d >= %d)\n",
__func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type,
APIR_BACKEND_DISPATCH_TABLE_COUNT);
return APIR_BACKEND_FORWARD_INDEX_INVALID;
}

View File

@@ -16,28 +16,32 @@ enum ApirCommandType {
APIR_COMMAND_TYPE_LOADLIBRARY = 1,
APIR_COMMAND_TYPE_FORWARD = 2,
APIR_COMMAND_TYPE_LENGTH = 3,
APIR_COMMAND_TYPE_LENGTH = 3,
};
typedef uint64_t ApirCommandFlags;
enum ApirLoadLibraryReturnCode {
APIR_LOAD_LIBRARY_SUCCESS = 0,
// these error codes are returned by the Virglrenderer APIR component
APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,
APIR_LOAD_LIBRARY_ALREADY_LOADED = 2,
APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3,
APIR_LOAD_LIBRARY_CANNOT_OPEN = 4,
APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5,
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code
// any value greater than this is an APIR *backend library* initialization return code
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6,
};
enum ApirForwardReturnCode {
APIR_FORWARD_SUCCESS = 0,
APIR_FORWARD_NO_DISPATCH_FCT = 1,
APIR_FORWARD_TIMEOUT = 2,
APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code
} ;
APIR_FORWARD_SUCCESS = 0,
// these error codes are returned by the Virglrenderer APIR component
APIR_FORWARD_NO_DISPATCH_FCT = 1,
APIR_FORWARD_TIMEOUT = 2,
APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3,
// any value greater than this index an APIR *backend library* forward return code
APIR_FORWARD_BASE_INDEX = 4,
};
__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {
switch (type) {
@@ -82,6 +86,7 @@ __attribute__((unused)) static const char * apir_forward_error(ApirForwardReturn
APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);
APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);
APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);
APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS);
APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);
return "Unknown APIR_COMMAND_TYPE_FORWARD error";

View File

@@ -34,3 +34,61 @@ typedef enum ApirBackendCommandType {
// last command_type index + 1
APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,
} ApirBackendCommandType;
static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {
switch (type) {
/* device */
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
return "device_get_device_count";
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
return "device_get_count";
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
return "device_get_name";
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
return "device_get_description";
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
return "device_get_type";
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
return "device_get_memory";
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
return "device_supports_op";
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
return "device_get_buffer_type";
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
return "device_get_props";
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
return "device_buffer_from_ptr";
/* buffer-type */
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
return "buffer_type_get_name";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
return "buffer_type_get_alignment";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
return "buffer_type_get_max_size";
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
return "buffer_type_is_host";
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
return "buffer_type_alloc_buffer";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
return "buffer_type_get_alloc_size";
/* buffer */
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
return "buffer_get_base";
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
return "buffer_set_tensor";
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
return "buffer_get_tensor";
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
return "buffer_cpy_tensor";
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
return "buffer_clear";
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
return "buffer_free_buffer";
/* backend */
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
return "backend_graph_compute";
default:
return "unknown";
}
}

View File

@@ -14,7 +14,7 @@
#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6
#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7
#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8
#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED 9
// new entries here need to be added to the apir_backend_initialize_error function below
@@ -39,6 +39,10 @@ static const char * apir_backend_initialize_error(int code) {
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED);
return "Unknown APIR_BACKEND_INITIALIZE error:/";

View File

@@ -13,7 +13,6 @@ struct apir_encoder {
const char * start;
const char * end;
bool fatal;
};
struct apir_decoder {
@@ -28,8 +27,8 @@ struct apir_decoder {
static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
apir_decoder dec = {
.cur = ptr,
.end = ptr + size,
.cur = ptr,
.end = ptr + size,
.fatal = false,
};
@@ -79,10 +78,7 @@ static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
* encode peek
*/
static inline bool apir_decoder_peek_internal(apir_decoder * dec,
size_t size,
void * val,
size_t val_size) {
static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) {
assert(val_size <= size);
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
@@ -332,8 +328,7 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t
static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
size_t alloc_size;
if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n",
__func__, size, count);
GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count);
return NULL;
}
@@ -352,20 +347,19 @@ static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
/* apir_buffer_type_host_handle_t */
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
const apir_buffer_type_host_handle_t * val) {
apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
}
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
apir_buffer_type_host_handle_t * val) {
apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
}
/* apir_buffer_host_handle_t */
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc,
const apir_buffer_host_handle_t * val) {
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) {
apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
}

View File

@@ -1,11 +1,10 @@
#include "ggml-impl.h"
#include "apir_cs.h"
#include "apir_cs_rpc.h"
#include "ggml-impl.h"
// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc,
const apir_buffer_host_handle_t * handle);
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle);
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);
@@ -22,8 +21,7 @@ static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
}
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec,
uint32_t n_tensors) {
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) {
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
@@ -45,9 +43,9 @@ static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {
}
ggml_init_params params{
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
/*.mem_size =*/ggml_tensor_overhead(),
/*.mem_buffer =*/NULL,
/*.no_alloc =*/true,
};
ggml_context * ctx = ggml_init(params);
@@ -105,6 +103,19 @@ static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec)
apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);
// SECURITY: Validate buffer handle against tracked buffers to prevent
// guest VM from providing arbitrary host memory addresses
if (buffer) {
extern std::unordered_set<ggml_backend_buffer_t> backend_buffers;
if (backend_buffers.find(buffer) == backend_buffers.end()) {
GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__,
(void *) buffer);
// Set fatal flag to prevent further processing with invalid handle
apir_decoder_set_fatal(dec);
return NULL;
}
}
return buffer;
}

View File

@@ -1,3 +1,6 @@
#pragma once
// clang-format off
#include "ggml.h"
#include "ggml-backend-impl.h"
@@ -5,6 +8,7 @@
#include <unordered_set>
#include <vector>
#include <cstdint>
// clang-format on
// ggml_tensor is serialized into apir_rpc_tensor
struct apir_rpc_tensor {

View File

@@ -34,6 +34,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml
static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
virtgpu * gpu = BUFT_TO_GPU(buft);
// Return the prefixed name that was built once during initialization
return gpu->cached_buffer_type.name;
}
@@ -53,9 +54,8 @@ static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buff
const ggml_tensor * tensor) {
virtgpu * gpu = BUFT_TO_GPU(buft);
if (tensor->buffer == NULL
|| !tensor->buffer->context
|| !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
if (tensor->buffer == NULL || !tensor->buffer->context ||
!buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
return ggml_nbytes(tensor);
}

View File

@@ -3,6 +3,7 @@
static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
// Return the prefixed name that was built once during initialization
return gpu->cached_device_info.name;
}
@@ -22,7 +23,7 @@ static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_bac
static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
virtgpu * gpu = DEV_TO_GPU(dev);
*free = gpu->cached_device_info.memory_free;
*free = gpu->cached_device_info.memory_free;
*total = gpu->cached_device_info.memory_total;
}
@@ -72,7 +73,7 @@ static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_
ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
static std::atomic<bool> initialized = false;
static std::atomic<bool> initialized = false;
static ggml_backend_buffer_type buft;
if (!initialized) {
@@ -95,7 +96,7 @@ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_bac
static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
static std::atomic<bool> initialized = false;
static std::atomic<bool> initialized = false;
static ggml_backend_buffer_type buft;
if (!initialized) {

View File

@@ -7,8 +7,8 @@
void ggml_virtgpu_cleanup(virtgpu * gpu);
static virtgpu * apir_initialize() {
static virtgpu * gpu = NULL;
static std::atomic<bool> initialized = false;
static virtgpu * gpu = NULL;
static std::atomic<bool> initialized = false;
if (initialized) {
// fast track
@@ -31,29 +31,53 @@ static virtgpu * apir_initialize() {
}
// Pre-fetch and cache all device information, it will not change
gpu->cached_device_info.description = apir_device_get_description(gpu);
gpu->cached_device_info.description = apir_device_get_description(gpu);
if (!gpu->cached_device_info.description) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__);
}
gpu->cached_device_info.name = apir_device_get_name(gpu);
if (!gpu->cached_device_info.name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__);
}
gpu->cached_device_info.device_count = apir_device_get_count(gpu);
gpu->cached_device_info.type = apir_device_get_type(gpu);
apir_device_get_memory(gpu,
&gpu->cached_device_info.memory_free,
&gpu->cached_device_info.memory_total);
{
// Get the remote name and create prefixed version
char * rmt_device_name = apir_device_get_name(gpu);
if (!rmt_device_name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__);
}
size_t device_name_len = strlen(rmt_device_name) + 11; // "[virtgpu] " + null terminator
gpu->cached_device_info.name = (char *) malloc(device_name_len);
if (!gpu->cached_device_info.name) {
free(rmt_device_name);
GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__);
}
snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name);
free(rmt_device_name);
}
apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total);
apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);
gpu->cached_buffer_type.host_handle = buft_host_handle;
gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle);
if (!gpu->cached_buffer_type.name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__);
{
// Get the remote name and create prefixed version
char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle);
if (!rmt_name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__);
}
size_t prefixed_len = strlen(rmt_name) + 11; // "[virtgpu] " + null terminator
gpu->cached_buffer_type.name = (char *) malloc(prefixed_len);
if (!gpu->cached_buffer_type.name) {
free(rmt_name);
GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__);
}
snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name);
free(rmt_name);
}
gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
initialized = true;
}
@@ -98,7 +122,7 @@ static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
static std::atomic<bool> initialized = false;
if (initialized) {
return; // fast track
return; // fast track
}
{

View File

@@ -1,5 +1,5 @@
#include "ggml-remoting.h"
#include "../../include/ggml-virtgpu.h"
#include "ggml-remoting.h"
static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) {
UNUSED(backend);

View File

@@ -9,7 +9,7 @@
#include <string>
#define GGML_VIRTGPU_NAME "ggml-virtgpu"
#define GGML_VIRTGPU "ggml-virtgpu: "
#define GGML_VIRTGPU "ggml-virtgpu: "
// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes

View File

@@ -3,7 +3,7 @@
#include <stdint.h>
struct virgl_renderer_capset_apir {
uint32_t apir_version;
uint32_t supports_blob_resources;
uint32_t reserved[4]; // For future expansion
uint32_t apir_version;
uint32_t supports_blob_resources;
uint32_t reserved[4]; // For future expansion
};

View File

@@ -145,8 +145,31 @@ class RemotingCodebaseGenerator:
enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},")
enum_lines.append("} ApirBackendCommandType;")
# Generate function name mapping
func_lines = []
func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {")
func_lines.append(" switch (type) {")
current_group = None
for func in functions:
# Add comment for new group
if func['group_name'] != current_group:
func_lines.append(f" /* {func['group_description']} */")
current_group = func['group_name']
# Generate clean function name without backend_ prefix
clean_name = f"{func['group_name']}_{func['function_name']}"
func_lines.append(f" case {func['enum_name']}:")
func_lines.append(f" return \"{clean_name}\";")
func_lines.append("")
func_lines.append(" default:")
func_lines.append(" return \"unknown\";")
func_lines.append(" }")
func_lines.append("}")
# Full header template
header_content = NL.join(enum_lines) + "\n"
header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n"
return header_content
@@ -170,19 +193,6 @@ class RemotingCodebaseGenerator:
decl_lines.append(f"{signature} {func['backend_function']}({params});")
# Switch cases
switch_lines = []
current_group = None
for func in functions:
if func['group_name'] != current_group:
switch_lines.append(f" /* {func['group_description']} */")
current_group = func['group_name']
deprecated = " (DEPRECATED)" if func['deprecated'] else ""
switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";")
# Dispatch table
table_lines = []
current_group = None
@@ -201,15 +211,6 @@ class RemotingCodebaseGenerator:
{NL.join(decl_lines)}
static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
{{
switch (type) {{
{NL.join(switch_lines)}
default: return "unknown";
}}
}}
extern "C" {{
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
{NL.join(table_lines)}

View File

@@ -17,8 +17,8 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data);
virtgpu_shmem temp_shmem; // Local storage for large buffers
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
if (cgraph_size <= gpu->data_shmem.mmap_size) {
// Lock mutex before using shared data_shmem buffer
@@ -26,7 +26,7 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
}
using_shared_shmem = true;
shmem = &gpu->data_shmem;
shmem = &gpu->data_shmem;
} else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) {
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
}

View File

@@ -62,7 +62,9 @@ size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle
return max_size;
}
apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) {
apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu,
apir_buffer_type_host_handle_t host_handle,
size_t size) {
apir_encoder * encoder;
apir_decoder * decoder;
ApirForwardReturnCode ret;
@@ -84,7 +86,9 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_t
return buffer_context;
}
size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) {
size_t apir_buffer_type_get_alloc_size(virtgpu * gpu,
apir_buffer_type_host_handle_t host_handle,
const ggml_tensor * op) {
apir_encoder * encoder;
apir_decoder * decoder;
ApirForwardReturnCode ret;

View File

@@ -35,8 +35,8 @@ void apir_buffer_set_tensor(virtgpu * gpu,
apir_encode_ggml_tensor(encoder, tensor);
virtgpu_shmem temp_shmem; // Local storage for large buffers
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
if (size <= gpu->data_shmem.mmap_size) {
// Lock mutex before using shared data_shmem buffer
@@ -44,7 +44,7 @@ void apir_buffer_set_tensor(virtgpu * gpu,
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
}
using_shared_shmem = true;
shmem = &gpu->data_shmem;
shmem = &gpu->data_shmem;
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
@@ -86,8 +86,8 @@ void apir_buffer_get_tensor(virtgpu * gpu,
apir_encode_ggml_tensor(encoder, tensor);
virtgpu_shmem temp_shmem; // Local storage for large buffers
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
if (size <= gpu->data_shmem.mmap_size) {
// Lock mutex before using shared data_shmem buffer
@@ -95,7 +95,7 @@ void apir_buffer_get_tensor(virtgpu * gpu,
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
}
using_shared_shmem = true;
shmem = &gpu->data_shmem;
shmem = &gpu->data_shmem;
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);

View File

@@ -26,7 +26,7 @@ char * apir_device_get_name(virtgpu * gpu) {
REMOTE_CALL(gpu, encoder, decoder, ret);
const size_t string_size = apir_decode_array_size_unchecked(decoder);
char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
if (!string) {
GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__);
return NULL;
@@ -173,7 +173,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, si
REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR);
if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) {
GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer");
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %ldb of guest-host shared buffer", __func__, size);
}
apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id);

View File

@@ -1,29 +1,36 @@
#include "virtgpu.h"
#pragma once
// clang-format off
#include "virtgpu.h"
#include "ggml-remoting.h"
#include "backend/shared/apir_backend.h"
#include "backend/shared/apir_cs_ggml.h"
#include "ggml-backend-impl.h"
// clang-format on
#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \
do { \
int32_t forward_flag = (int32_t) apir_command_type__; \
encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \
if (!encoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \
} \
#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \
int32_t REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__; \
const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \
do { \
encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \
if (!encoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \
} \
} while (0)
#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \
do { \
ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
if (!decoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \
} \
if (ret_name < APIR_FORWARD_BASE_INDEX) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \
apir_forward_error(ret_name), ret_name); \
} \
ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \
#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \
do { \
ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
if (!decoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \
} \
if (ret_name < APIR_FORWARD_BASE_INDEX) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \
apir_forward_error(ret_name), ret_name); \
} \
ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \
if (ret_name != 0) { \
GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)", \
REMOTE_CALL_PREPARE_command_name, ret_name); \
} \
} while (0)

View File

@@ -20,6 +20,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu,
char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
/* apir_buffer_type_is_host is deprecated. */
apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu,
apir_buffer_type_host_handle_t host_handle,
size_t size);

View File

@@ -53,9 +53,9 @@ static int virtgpu_handshake(virtgpu * gpu) {
if (!decoder) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initiate the communication with the virglrenderer library. "
"Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
__func__);
"%s: failed to initiate the communication with the virglrenderer library. "
"Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
__func__);
return 1;
}
@@ -65,8 +65,7 @@ static int virtgpu_handshake(virtgpu * gpu) {
uint32_t host_minor;
if (ret_magic != APIR_HANDSHAKE_MAGIC) {
GGML_ABORT(GGML_VIRTGPU
"%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
apir_backend_initialize_error(ret_magic));
} else {
apir_decode_uint32_t(decoder, &host_major);
@@ -140,15 +139,13 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
"Make sure virglrenderer is correctly configured by the hypervisor. (%s) ",
__func__, apir_load_library_error(ret));
} else {
GGML_ABORT(GGML_VIRTGPU
"%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__,
apir_load_library_error(ret), ret);
GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)",
__func__, apir_load_library_error(ret), ret);
}
return ret;
}
GGML_LOG_INFO(GGML_VIRTGPU
"%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
@@ -158,10 +155,11 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
__func__, apir_load_library_error(apir_ret));
} else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) {
GGML_ABORT(GGML_VIRTGPU
"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
__func__, apir_load_library_error(apir_ret));
GGML_ABORT(
GGML_VIRTGPU
"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
__func__, apir_load_library_error(apir_ret));
} else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {
GGML_ABORT(GGML_VIRTGPU
"%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)",
@@ -169,8 +167,8 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
} else {
uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX;
GGML_ABORT(GGML_VIRTGPU
"%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__,
lib_ret);
"%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)",
__func__, lib_ret);
}
return ret;
}
@@ -184,55 +182,49 @@ virtgpu * create_virtgpu() {
// Initialize mutex to protect shared data_shmem buffer
if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) {
delete gpu;
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize data_shmem mutex", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__);
return NULL;
}
if (virtgpu_open(gpu) != APIR_SUCCESS) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to open the virtgpu device\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__);
return NULL;
}
if (virtgpu_init_capset(gpu) != APIR_SUCCESS) {
if (gpu->use_apir_capset) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__);
"%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library "
"supports it.",
__func__);
} else {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize the virtgpu Venus capset", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__);
}
return NULL;
}
if (virtgpu_init_context(gpu) != APIR_SUCCESS) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize the GPU context", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__);
return NULL;
}
if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to create the shared reply memory pages", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__);
return NULL;
}
if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to create the shared data memory pages", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__);
return NULL;
}
if (virtgpu_handshake(gpu)) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to handshake with the virglrenderer library", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__);
return NULL;
}
if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to load the backend library", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__);
return NULL;
}
@@ -243,8 +235,7 @@ static virt_gpu_result_t virtgpu_open(virtgpu * gpu) {
drmDevicePtr devs[8];
int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs));
if (count < 0) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to enumerate DRM devices\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__);
return APIR_ERROR_INITIALIZATION_FAILED;
}
@@ -266,19 +257,17 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d
int fd = open(node_path, O_RDWR | O_CLOEXEC);
if (fd < 0) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to open %s", __func__, node_path);
GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path);
return APIR_ERROR_INITIALIZATION_FAILED;
}
drmVersionPtr version = drmGetVersion(fd);
if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) {
if (version) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name,
version->version_major);
} else {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to get DRM driver version\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__);
}
if (version) {
@@ -322,9 +311,8 @@ static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) {
virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data));
if (ret) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to get APIR v%d capset: %s\n",
__func__, gpu->capset.version, strerror(errno));
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get APIR v%d capset: %s\n", __func__, gpu->capset.version,
strerror(errno));
return APIR_ERROR_INITIALIZATION_FAILED;
}
@@ -547,13 +535,10 @@ static void log_call_duration(long long call_duration_ns, const char * name) {
double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds
if (call_duration_s > 1) {
GGML_LOG_INFO(GGML_VIRTGPU
"waited %.2fs for the %s host reply...\n", call_duration_s, name);
GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name);
} else if (call_duration_ms > 1) {
GGML_LOG_INFO(GGML_VIRTGPU
"waited %.2fms for the %s host reply...\n", call_duration_ms, name);
GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name);
} else {
GGML_LOG_INFO(GGML_VIRTGPU
"waited %lldns for the %s host reply...\n", call_duration_ns, name);
GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name);
}
}

View File

@@ -1,5 +1,6 @@
#pragma once
// clang-format off
#include "virtgpu-utils.h"
#include "virtgpu-shm.h"
#include "virtgpu-apir.h"
@@ -23,20 +24,21 @@
#include "apir_hw.h"
#include <drm/virtgpu_drm.h>
#include "venus_hw.h"
// clang-format on
#ifndef VIRTGPU_DRM_CAPSET_APIR
// Will be defined include/drm/virtgpu_drm.h when
// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs
// is merged
#define VIRTGPU_DRM_CAPSET_APIR 10
# define VIRTGPU_DRM_CAPSET_APIR 10
#endif
// Mesa/Virlgrenderer Venus internal. Only necessary during the
// Venus->APIR transition in Virglrenderer
#define VENUS_COMMAND_TYPE_LENGTH 331
#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16
#define VIRTGPU_DRM_CAPSET_VENUS 4
#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16
# define VIRTGPU_DRM_CAPSET_VENUS 4
#endif
typedef uint32_t virgl_renderer_capset;

File diff suppressed because it is too large Load Diff

View File

@@ -3,9 +3,13 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
@@ -15,8 +19,10 @@
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// If SubGroupSize is set to 0 then only use shmem reductions
const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
shared float tmpsh[tmpsh_size];
shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
const uint32_t D = HSK > HSV ? HSK : HSV;
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -50,8 +58,24 @@ void main() {
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
if (LIMIT_OCCUPANCY_SHMEM > 0) {
// This just exists to avoid the occupancy_limiter array getting optimized out
occupancy_limiter[tid] = vec4(tid);
barrier();
if (occupancy_limiter[tid] == vec4(99999.0)) {
data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);
}
}
#define tile_row(r) (row_tid * rows_per_thread + (r))
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
@@ -60,37 +84,37 @@ void main() {
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
vec4 Of[Br][HSV_per_thread / 4];
FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = FLOAT_TYPEV4(0.0);
}
}
float Lf[Br], Mf[Br];
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
float slope[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = 1.0;
ACC_TYPE slope[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
slope[r] = ACC_TYPE(1.0);
}
// ALiBi
if (p.max_bias > 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
}
}
@@ -113,75 +137,141 @@ void main() {
uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
float max_mask = NEG_FLT_MAX_OVER_2;
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c * masksh_stride + r] = m;
max_mask = max(max_mask, float(m));
} else {
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
ACC_TYPE Sf[rows_per_thread][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = ACC_TYPE(0.0);
}
}
if (SHMEM_STAGING != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
barrier();
}
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 Q_cache[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
masksh[c][r] = float(0);
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
} else {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = 0.0;
}
}
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
}
}
}
}
@@ -189,89 +279,109 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
}
}
}
if (LOGIT_SOFTCAP) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
}
}
}
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
Moldf[r] = Mf[r];
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
}
eMf[r] = exp(Moldf[r] - Mf[r]);
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];
}
}
if (SHMEM_STAGING != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV / 4);
uint32_t c = (idx + tid) / (HSV / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
barrier();
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPE Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
FLOAT_TYPEV4 Vf;
if (SHMEM_STAGING != 0) {
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] += Pf[r][c] * Vf;
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
}
}
}
barrier();
}
// prevent race on tmpsh
@@ -279,58 +389,115 @@ void main() {
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float rowmaxf, eMf;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = Mf[r];
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
if (SubGroupSize > 0) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
}
if (row_split == 1) {
// Reduce inside workgroup with shmem
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
}
barrier();
rowmaxf = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
}
}
} else {
barrier();
tmpsh[tid] = rowmaxf;
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
}
barrier();
}
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
float eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
if (SubGroupSize > 0) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Lf[r] += subgroupShuffleXor(Lf[r], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
}
barrier();
Lf[r] = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Lf[r] += tmpsh[s * D_split + d_tid];
}
}
} else {
barrier();
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
tmpsh[tid] = Lf[r];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += tmpshv4[tid + s];
tmpshv4[tid] = Of[r][d];
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
}
barrier();
}
Of[r][d] = tmpshv4[d_tid];
barrier();
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d];
if (SubGroupSize > 0) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
if (!OLD_AMD_WINDOWS) {
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
} else {
// Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
// Shuffle full vec4 as workaround.
// See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s));
}
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
}
barrier();
Of[r][d] = tmpshv4[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Of[r][d] += tmpshv4[s * D_split + d_tid];
}
}
} else {
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
Of[r][d] += tmpshv4[tid ^ s];
tmpshv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
}
}
}
@@ -338,33 +505,53 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;
if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}
if (global_row < N && d_tid == 0 && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
@@ -373,7 +560,7 @@ void main() {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
Of[r][d] *= FLOAT_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
@@ -383,39 +570,37 @@ void main() {
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);
#if defined(FLOAT_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
#endif
}
}
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (i * Br + row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}
}

View File

@@ -1,20 +1,23 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
layout (constant_id = 9) const uint32_t Flags = 0;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t row_split = 1;
layout (constant_id = 8) const uint32_t SubGroupSize = 32;
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 10) const uint32_t Flags = 0;
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
const bool OLD_AMD_WINDOWS = (Flags & 8) != 0;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -69,6 +72,7 @@ layout (push_constant) uniform parameter {
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
@@ -94,12 +98,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
} else {
return v_packed.v_data_packed[a_offset + ib];
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
}
}
#endif
@@ -107,7 +111,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -115,7 +119,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -123,24 +127,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
@@ -189,10 +193,16 @@ void init_indices()
KV = p.KV;
if (p.k_num > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
if (p.gqa_ratio > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else {
gqa_iq1 = 0;
split_k_index = gl_WorkGroupID.x % p.k_num;
i = gl_WorkGroupID.x / p.k_num;
}
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;
@@ -244,3 +254,11 @@ void init_indices()
// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV / 4 + c;
data_ov4[o_offset + offset] = D_TYPEV4(elems);
}

View File

@@ -19,7 +19,6 @@
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
const uint32_t row_split = Bc / MatBc;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -33,15 +32,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
@@ -54,10 +44,14 @@ shared f16vec4 Psh[Bc * psh_stride];
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
const uint32_t osh_stride = row_split * MatBr / 4;
shared f16vec4 pvsh[MatBc * osh_stride];
shared ACC_TYPE slope[Br];
@@ -84,11 +78,6 @@ void main() {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
}
}
barrier();
}
@@ -104,10 +93,10 @@ void main() {
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
f16vec4 Of[rows_per_thread][d_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
Of[r][d] = ACC_TYPEV4(0.0);
Of[r][d] = f16vec4(0.0);
}
}
@@ -153,22 +142,22 @@ void main() {
uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
mask_cache[idx] = f16vec4(0);
}
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
@@ -231,24 +220,24 @@ void main() {
}
}
if (K_LOAD_SHMEM != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK_pad / 4);
uint32_t c = (idx + tid) / (HSK_pad / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
ksh[c * kshstride + d] = K_Tf;
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
barrier();
@@ -262,7 +251,11 @@ void main() {
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
if (K_LOAD_SHMEM == 0) {
// If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
// If not, f16 K is loaded directly from global memory if aligned, otherwise
// staged through a Bc * MatBr size staging buffer.
// If K is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
@@ -277,13 +270,13 @@ void main() {
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
}
ksh[row * kshstride + col_vec] = K_Tf;
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
}
barrier();
@@ -295,8 +288,8 @@ void main() {
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
uint coord = (gl_SubgroupID * MatBc) * kshstride;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
@@ -305,8 +298,8 @@ void main() {
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
@@ -329,7 +322,7 @@ void main() {
barrier();
}
if (MASK_ENABLE) {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
@@ -374,7 +367,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
}
}
@@ -397,19 +390,47 @@ void main() {
}
}
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV_pad / 4);
uint32_t c = (idx + tid) / (HSV_pad / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
f16vec4 V_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
}
barrier();
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
// Each subgroup handles HSV/4 columns
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
// Preload V tiles for [Bc, 16 * num subgroups]
const uint v_rows = Bc;
const uint v_total = v_rows * v_cols;
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
// If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
// If not, f16 V is loaded directly from global memory if aligned, otherwise
// staged through a Bc * MatBr size staging buffer.
// If V is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
@@ -428,44 +449,52 @@ void main() {
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
ksh[row * vsh_stride + col] = f16vec4(0.0f);
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
#if BLOCK_SIZE == 1
}
#endif
}
barrier();
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
const uint o_offset = gl_SubgroupID * MatBr / 4;
if (hsv_offset < HSV_pad) {
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
} else {
const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
PVMat = coopMatMulAdd(KMat, QMat, PVMat);
}
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
// Store PVMat to pvsh and load into Of
coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
// Store SfMat to sfsh and load into Of
const uint osh_stride = row_split * MatBc / 4;
const uint o_offset = gl_SubgroupID * MatBc / 4;
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
const uint hsv_per_tile = row_split * MatBc;
@@ -484,7 +513,7 @@ void main() {
if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
const uint local_hsv = (hsv_col - hsv_base) / 4;
Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
}
}
}
@@ -500,27 +529,48 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;
if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
}
}
if (global_row < N && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}
@@ -539,7 +589,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
Of[r][d_local] *= ACC_TYPE(ms);
Of[r][d_local] *= float16_t(ms);
}
} else {
vs = exp(sink - Mf[r]);
@@ -557,14 +607,14 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
Of[r][d_local] *= float16_t(Lfrcp[r]);
#if defined(FLOAT_TYPE_MAX)
Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
#endif
}
}
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -573,9 +623,7 @@ void main() {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}
@@ -586,9 +634,7 @@ void main() {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
}
data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
}
}
}

View File

@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
// Store O values for non-GQA split_k. Rows are tokens, not heads.
D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c < HSV) {
uint32_t o_off = HSV * p.ne1
* (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
}
return elem;
}
// Store L/M values for non-GQA split_k.
ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c == 0) {
uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+ p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
}
return elem;
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -290,13 +312,19 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
} else {
coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
}
return;
}

View File

@@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
// No coopmats
@@ -622,49 +620,63 @@ void process_shaders() {
}
}
// flash attention
for (const auto& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
// flash attention
for (const bool& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
if (fp16 && f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
}
}
}
}
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);

View File

@@ -1,12 +1,19 @@
ggml_add_backend_library(ggml-zendnn
ggml-zendnn.cpp)
# Get ZenDNN path
if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "")
set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}")
endif()
# Check if path is still empty or OFF
if (BUILD_SHARED_LIBS)
set(ZENDNN_SHARED_LIB ON)
set(ZENDNN_ARCHIVE_LIB OFF)
else()
set(ZENDNN_SHARED_LIB OFF)
set(ZENDNN_ARCHIVE_LIB ON)
endif()
# Download and build ZenDNN if not provided
if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...")
message(STATUS "This will take several minutes on first build...")
@@ -21,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG 21ce8f7879c86bf3637f707fae6f29e0951db5fe
GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}
@@ -32,7 +39,9 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
-DZENDNNL_BUILD_DOXYGEN=OFF
-DZENDNNL_BUILD_GTEST=OFF
-DZENDNNL_BUILD_BENCHDNN=OFF
# Enable ALL matmul algorithm backends
-DZENDNNL_DEPENDS_FBGEMM=OFF
-DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB}
-DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB}
-DZENDNNL_DEPENDS_AOCLDLP=ON
-DZENDNNL_DEPENDS_ONEDNN=ON
-DZENDNNL_DEPENDS_LIBXSMM=ON
@@ -45,47 +54,37 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
LOG_INSTALL ON
)
# Add dependency so ZenDNN builds before our library
add_dependencies(ggml-zendnn zendnn)
# Set ZENDNN_ROOT to the installation directory
set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR})
message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}")
else()
message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}")
endif()
# ZenDNN headers + libs
target_include_directories(ggml-zendnn PRIVATE
${ZENDNN_ROOT}/zendnnl/include
${ZENDNN_ROOT}/deps/aocldlp/include
${ZENDNN_ROOT}/deps/aoclutils/include
${ZENDNN_ROOT}/deps/json/include
${ZENDNN_ROOT}/deps/libxsmm/include
${ZENDNN_ROOT}/deps/aoclutils/include
${ZENDNN_ROOT}/deps/aocldlp/include
${ZENDNN_ROOT}/deps/onednn/include
)
${ZENDNN_ROOT}/deps/libxsmm/include)
target_link_directories(ggml-zendnn PRIVATE
${ZENDNN_ROOT}/zendnnl/lib
${ZENDNN_ROOT}/deps/aocldlp/lib
${ZENDNN_ROOT}/deps/aoclutils/lib
${ZENDNN_ROOT}/deps/libxsmm/lib
${ZENDNN_ROOT}/deps/onednn/lib
)
if (ZENDNN_SHARED_LIB)
target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib)
target_link_libraries(ggml-zendnn PRIVATE zendnnl)
elseif (ZENDNN_ARCHIVE_LIB)
target_link_libraries(ggml-zendnn PRIVATE
${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a
${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a
${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a
${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a
${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a
${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a
${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a
${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a)
endif()
target_link_libraries(ggml-zendnn PRIVATE
zendnnl_archive # ZenDNN main
aocl-dlp # AOCL libraries
aoclutils
au_cpuid
dnnl # OneDNN
xsmm # libxsmm small matrix math
xsmmext
xsmmnoblas
m
pthread
)
target_link_libraries(ggml-zendnn PRIVATE m pthread)
if (GGML_OPENMP)
target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX)

View File

@@ -41,13 +41,13 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
int64_t ldc) {
zendnnl::lowoha::lowoha_params params;
zendnnl::lowoha::matmul::matmul_params params;
params.dtypes.src = ggml_to_zendnn_type<TB>();
params.dtypes.wei = ggml_to_zendnn_type<TA>();
params.dtypes.dst = ggml_to_zendnn_type<TC>();
params.num_threads = ctx->n_threads;
zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
n, // M: rows of B and C
m, // N: cols of A^T and C
@@ -63,7 +63,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
params // params
);
if (status != zendnnl::lowoha::status_t::success) {
if (status != zendnnl::error_handling::status_t::success) {
GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status));
return false;
}

View File

@@ -899,7 +899,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
};
const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
GGML_ASSERT(type < GGML_TYPE_COUNT);
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return &type_traits[type];
}
@@ -1265,27 +1266,33 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
}
int64_t ggml_blck_size(enum ggml_type type) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].blck_size;
}
size_t ggml_type_size(enum ggml_type type) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].type_size;
}
size_t ggml_row_size(enum ggml_type type, int64_t ne) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
assert(ne % ggml_blck_size(type) == 0);
return ggml_type_size(type)*ne/ggml_blck_size(type);
}
double ggml_type_sizef(enum ggml_type type) {
return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
}
const char * ggml_type_name(enum ggml_type type) {
return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].type_name;
}
bool ggml_is_quantized(enum ggml_type type) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].is_quantized;
}
@@ -1629,11 +1636,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
const size_t cur_end = cur_offs + cur_size;
// align to GGML_MEM_ALIGN
GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1));
size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
char * const mem_buffer = ctx->mem_buffer;
struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
// integer overflow checks
if (cur_end > SIZE_MAX - size_needed) {
GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed);
return NULL;
}
if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) {
GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__,
cur_end, size_needed, (size_t) GGML_OBJECT_SIZE);
return NULL;
}
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
@@ -1702,6 +1721,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
obj_alloc_size = data_size;
}
GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size);
struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
GGML_ASSERT(obj_new);

View File

@@ -15,6 +15,17 @@
#include <string>
#include <vector>
#define GGUF_MAX_STRING_LENGTH (1024*1024*1024)
#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024)
#ifdef _WIN32
# define gguf_ftell _ftelli64
# define gguf_fseek _fseeki64
#else
# define gguf_ftell ftello
# define gguf_fseek fseeko
#endif
template <typename T>
struct type_to_gguf_type;
@@ -217,17 +228,64 @@ struct gguf_context {
};
struct gguf_reader {
FILE * file;
gguf_reader(FILE * file) : file(file) {
// read the remaining bytes once and update on each read
nbytes_remain = file_remain(file);
}
gguf_reader(FILE * file) : file(file) {}
// helper for remaining bytes in a file
static uint64_t file_remain(FILE * file) {
const int64_t cur = gguf_ftell(file);
if (cur < 0) {
return 0;
}
if (gguf_fseek(file, 0, SEEK_END) != 0) {
gguf_fseek(file, cur, SEEK_SET);
return 0;
}
const int64_t end = gguf_ftell(file);
if (end < 0) {
gguf_fseek(file, cur, SEEK_SET);
return 0;
}
gguf_fseek(file, cur, SEEK_SET);
return static_cast<uint64_t>(end - cur);
}
template <typename T>
bool read(T & dst) const {
return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
const size_t size = sizeof(dst);
if (nbytes_remain < size) {
return false;
}
const size_t nread = fread(&dst, 1, size, file);
nbytes_remain -= nread;
return nread == size;
}
template <typename T>
bool read(std::vector<T> & dst, const size_t n) const {
if (n > GGUF_MAX_ARRAY_ELEMENTS) {
return false;
}
if constexpr (std::is_same<T, std::string>::value) {
// strings are prefixed with their length, so we need to account for that
if (n > SIZE_MAX / sizeof(uint64_t)) {
return false;
}
if (nbytes_remain < n * sizeof(uint64_t)) {
return false;
}
} else {
if (n > SIZE_MAX / sizeof(T)) {
return false;
}
if (nbytes_remain < n * sizeof(T)) {
return false;
}
}
dst.resize(n);
for (size_t i = 0; i < dst.size(); ++i) {
if constexpr (std::is_same<T, bool>::value) {
@@ -277,13 +335,33 @@ struct gguf_reader {
if (!read(size)) {
return false;
}
dst.resize(size);
return fread(dst.data(), 1, dst.length(), file) == dst.length();
if (size > GGUF_MAX_STRING_LENGTH) {
GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH);
return false;
}
if (size > nbytes_remain) {
GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain);
return false;
}
dst.resize(static_cast<size_t>(size));
const size_t nread = fread(dst.data(), 1, size, file);
nbytes_remain -= nread;
return nread == size;
}
bool read(void * dst, const size_t size) const {
return fread(dst, 1, size, file) == size;
if (size > nbytes_remain) {
return false;
}
const size_t nread = fread(dst, 1, size, file);
nbytes_remain -= nread;
return nread == size;
}
private:
FILE * file;
mutable uint64_t nbytes_remain;
};
struct gguf_context * gguf_init_empty(void) {
@@ -568,8 +646,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// check that tensor type is within defined range
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n",
__func__, info.t.name, info.t.type, GGML_TYPE_COUNT);
ok = false;
break;
}
@@ -618,14 +696,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
// we require the data section to be aligned, so take into account any padding
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) {
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
gguf_free(ctx);
return nullptr;
}
// store the current file offset - this is where the data section starts
ctx->offset = ftell(file);
ctx->offset = gguf_ftell(file);
// compute the total size of the data section, taking into account the alignment
{
@@ -657,10 +735,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// the ggml_tensor structs to the appropriate locations in the binary blob
// compute the exact size needed for the new ggml_context
const size_t mem_size =
params.no_alloc ?
(n_tensors )*ggml_tensor_overhead() :
(n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
size_t mem_size = 0;
if (params.no_alloc) {
if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) {
GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
gguf_free(ctx);
return nullptr;
}
const size_t overhead = n_tensors * ggml_tensor_overhead();
mem_size = overhead;
} else {
if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) {
GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
gguf_free(ctx);
return nullptr;
}
const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead();
if (SIZE_MAX - overhead < ctx->size) {
GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
gguf_free(ctx);
return nullptr;
}
mem_size = overhead + ctx->size;
}
struct ggml_init_params pdata = {
/*mem_size =*/ mem_size,

View File

@@ -379,6 +379,7 @@ class MODEL_ARCH(IntEnum):
NEO_BERT = auto()
JINA_BERT_V2 = auto()
JINA_BERT_V3 = auto()
EUROBERT = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
@@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_EXP = auto()
FFN_DOWN_EXP = auto()
FFN_UP_EXP = auto()
FFN_GATE_UP_EXP = auto()
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
@@ -820,6 +822,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.NEO_BERT: "neo-bert",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3",
MODEL_ARCH.EUROBERT: "eurobert",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
@@ -978,6 +981,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
@@ -1587,6 +1591,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.EUROBERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -1805,6 +1822,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
@@ -1894,6 +1912,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
@@ -2595,6 +2614,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,

View File

@@ -175,6 +175,9 @@ class GGUFReader:
if new_align.types != [GGUFValueType.UINT32]:
raise ValueError('Bad type for general.alignment field')
self.alignment = new_align.parts[-1][0]
# Ensure alignment is a non-zero power of two
if self.alignment == 0 or (self.alignment & (self.alignment - 1)) != 0:
raise ValueError('Invalid alignment: must be a non-zero power of two')
padding = offs % self.alignment
if padding != 0:
offs += self.alignment - padding
@@ -202,11 +205,11 @@ class GGUFReader:
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
if field.name in self.fields:
# TODO: add option to generate error on duplicate keys
# raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
# TODO: add option to make this a warning and accept duplicate keys like below
raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
self.fields[field.name + '_{}'.format(field.offset)] = field
# logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
# self.fields[field.name + '_{}'.format(field.offset)] = field
else:
self.fields[field.name] = field
return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

View File

@@ -501,6 +501,8 @@ class GGUFWriter:
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_custom_alignment(self, alignment: int) -> None:
if alignment <= 0 or (alignment & (alignment - 1)) != 0:
raise ValueError('Invalid alignment: must be a non-zero power of two')
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)

View File

@@ -567,6 +567,10 @@ class TensorNameMap:
"model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
),
MODEL_TENSOR.FFN_GATE_UP_EXP: (
"model.layers.{bid}.mlp.experts.gate_up_proj",
),
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox

View File

@@ -25,16 +25,12 @@ Example usage:
"""
def generate_input_prompt(length: int) -> list[str]:
CORPUS = """
You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
### Tool Call Format:
When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
You can make multiple calls in one go by placing them one after another.
"""
words = [w.strip() for w in CORPUS.strip().split(" ")]
def get_remote_corpus(url: str, length: int) -> list[str]:
response = requests.get(url)
response.raise_for_status()
corpus = response.text
words = [w.strip() for w in corpus.strip().split(" ")]
words = [w for w in words if "<" not in w] # make sure nothing looks like special tokens
words = [w for w in words if len(w) > 0] # filter out empty strings
while len(words) < length:
words += words
@@ -226,9 +222,9 @@ def parse_args() -> argparse.Namespace:
)
parser_dump.add_argument(
"--file",
type=Path,
default=None,
help="File containing prompt to use instead of the default",
type=str,
default="https://raw.githubusercontent.com/ggml-org/llama.cpp/eaba92c3dcc980ebe753348855d4a5d75c069997/tools/server/README.md",
help="File containing prompt to use instead of the default (can also be an URL)",
)
parser_dump.add_argument(
"--pattern",
@@ -259,17 +255,19 @@ def main():
if args.verb == "dump":
pattern = parse_pattern(args.pattern)
input_length = sum(n for _, n in pattern)
input_words = generate_input_prompt(input_length)
if args.file is not None:
with args.file.open("r") as f:
required_words = sum(n for _, n in pattern)
if args.file.startswith("http"):
input_words = get_remote_corpus(args.file, required_words)
logger.info(f"Fetched {len(input_words)} words from remote {args.file}")
else:
with open(args.file, "r") as f:
input_words = f.read().strip().split(" ")
if input_length < sum(n for _, n in pattern):
input_words = [w for w in input_words if len(w) > 0] # filter out empty strings
if len(input_words) < required_words:
raise ValueError(
f"Input file has only {input_length} words, but pattern requires at least {input_length} words."
f"Input file has only {len(input_words)} words, but pattern requires at least {required_words} words."
)
input_length = len(input_words)
logger.info(f"Using {input_length} words")
logger.info(f"Using {len(input_words)} words")
dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
elif args.verb == "compare":
compare_logits(args.input1, args.input2, args.output)

View File

@@ -62,6 +62,7 @@ add_library(llama
models/dream.cpp
models/ernie4-5-moe.cpp
models/ernie4-5.cpp
models/eurobert.cpp
models/exaone-moe.cpp
models/exaone.cpp
models/exaone4.cpp

View File

@@ -26,6 +26,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NEO_BERT, "neo-bert" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" },
{ LLM_ARCH_EUROBERT, "eurobert" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
@@ -348,6 +349,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
@@ -819,6 +821,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
};
case LLM_ARCH_EUROBERT:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
};
case LLM_ARCH_MODERN_BERT:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -989,6 +1005,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -1046,6 +1063,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -1586,6 +1604,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -2670,6 +2689,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},

View File

@@ -30,6 +30,7 @@ enum llm_arch {
LLM_ARCH_NEO_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_JINA_BERT_V3,
LLM_ARCH_EUROBERT,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
@@ -372,6 +373,7 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,

View File

@@ -1165,7 +1165,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in) const {
ggml_tensor * probs_in,
ggml_tensor * gate_up_exps) const {
return build_moe_ffn(
cur,
gate_inp, /* gate_inp_b */ nullptr,
@@ -1181,7 +1182,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
w_scale,
gating_op,
il,
probs_in
probs_in,
gate_up_exps
);
}
@@ -1204,7 +1206,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in) const {
ggml_tensor * probs_in,
ggml_tensor * gate_up_exps,
ggml_tensor * gate_up_exps_b) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1343,26 +1347,48 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_weighted", il);
}
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
if (up_exps_b) {
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
cb(up, "ffn_moe_up_biased", il);
}
ggml_tensor * up = nullptr;
ggml_tensor * experts = nullptr;
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
if (gate_up_exps) {
// merged gate_up path: one mul_mat_id, then split into gate and up views
ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
cb(gate_up, "ffn_moe_gate_up", il);
if (gate_up_exps_b) {
gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
cb(gate_up, "ffn_moe_gate_up_biased", il);
}
const int64_t n_ff = gate_up->ne[0] / 2;
cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
cb(cur, "ffn_moe_gate", il);
up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
cb(up, "ffn_moe_up", il);
} else {
cur = up;
// separate gate and up path
up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
if (up_exps_b) {
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
cb(up, "ffn_moe_up_biased", il);
}
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);
} else {
cur = up;
}
if (gate_exps_b) {
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
cb(cur, "ffn_moe_gate_biased", il);
}
}
if (gate_exps_b) {
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
cb(cur, "ffn_moe_gate_biased", il);
}
const bool has_gate = gate_exps || gate_up_exps;
switch (type_op) {
case LLM_FFN_SILU:
@@ -1385,7 +1411,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
break;
}
}
}
if (has_gate) {
cur = ggml_swiglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_swiglu", il);
} else {
@@ -1393,7 +1421,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
if (gate_exps) {
if (has_gate) {
cur = ggml_geglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_geglu", il);
} else {
@@ -1409,7 +1437,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_swiglu_oai", il);
} break;
case LLM_FFN_RELU:
if (gate_exps) {
if (has_gate) {
cur = ggml_reglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_reglu", il);
} else {
@@ -1417,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_relu", il);
} break;
case LLM_FFN_RELU_SQR:
if (gate_exps) {
if (has_gate) {
// TODO: add support for gated squared relu
GGML_ABORT("fatal error: gated squared relu not implemented");
} else {

View File

@@ -814,7 +814,8 @@ struct llm_graph_context {
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in = nullptr) const;
ggml_tensor * probs_in = nullptr,
ggml_tensor * gate_up_exps = nullptr) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
@@ -835,7 +836,9 @@ struct llm_graph_context {
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in = nullptr) const;
ggml_tensor * probs_in = nullptr,
ggml_tensor * gate_up_exps = nullptr,
ggml_tensor * gate_up_exps_b = nullptr) const;
//
// inputs

View File

@@ -978,6 +978,9 @@ bool llama_kv_cache::get_can_shift() const {
if (model.arch == LLM_ARCH_STEP35) {
return false;
}
if (hparams.n_pos_per_embd() > 1) {
return false;
}
return true;
}

View File

@@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
const auto & cell = cells[tail_id];
// partial intersection is invalid if it includes the final pos
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
return false;
}
// invalidate tails which will be cleared

View File

@@ -123,6 +123,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_16B_A1B: return "16B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_24B_A2B: return "24B.A2B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_31B_A3_5B: return "31B.A3.5B";
case LLM_TYPE_35B_A3B: return "35B.A3B";
@@ -978,6 +979,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
type = LLM_TYPE_250M;
}
} break;
case LLM_ARCH_EUROBERT:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
if (hparams.n_layer == 12) {
type = LLM_TYPE_SMALL; // 0.2B
}
} break;
case LLM_ARCH_BLOOM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2381,7 +2392,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
type = LLM_TYPE_8B_A1B;
switch (hparams.n_layer) {
case 24: type = LLM_TYPE_8B_A1B; break;
case 40: type = LLM_TYPE_24B_A2B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_SMALLTHINKER:
{
@@ -2965,6 +2980,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// TODO: move to a separate function
const auto tn = LLM_TN(arch);
// helper: try merged gate_up_exps first, fall back to separate gate and up
auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) {
layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED);
if (layer.ffn_gate_up_exps == nullptr) {
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
}
};
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
@@ -3565,6 +3589,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
}
} break;
case LLM_ARCH_EUROBERT:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
}
} break;
case LLM_ARCH_JINA_BERT_V2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -5183,9 +5230,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
// MoE branch
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Shared expert branch
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
@@ -7387,9 +7433,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Shared experts
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
@@ -7453,9 +7498,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Shared experts
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
@@ -8176,6 +8220,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_EUROBERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_MODERN_BERT:
case LLM_ARCH_GEMMA_EMBEDDING:
@@ -8373,6 +8418,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_neo_bert>(*this, params);
} break;
case LLM_ARCH_EUROBERT:
{
llm = std::make_unique<llm_build_eurobert>(*this, params);
} break;
case LLM_ARCH_BLOOM:
{
llm = std::make_unique<llm_build_bloom>(*this, params);
@@ -8999,6 +9048,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MODERN_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_EUROBERT:
case LLM_ARCH_STABLELM:
case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN:

View File

@@ -116,6 +116,7 @@ enum llm_type {
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_24B_A2B, // lfm2moe
LLM_TYPE_30B_A3B,
LLM_TYPE_31B_A3_5B,
LLM_TYPE_35B_A3B, // Qwen3.5
@@ -279,14 +280,16 @@ struct llama_layer {
struct ggml_tensor * ffn_up_enc = nullptr;
// ff MoE
struct ggml_tensor * ffn_gate_inp = nullptr;
struct ggml_tensor * ffn_gate_exps = nullptr;
struct ggml_tensor * ffn_down_exps = nullptr;
struct ggml_tensor * ffn_up_exps = nullptr;
struct ggml_tensor * ffn_gate_inp_b = nullptr;
struct ggml_tensor * ffn_gate_exps_b = nullptr;
struct ggml_tensor * ffn_down_exps_b = nullptr;
struct ggml_tensor * ffn_up_exps_b = nullptr;
struct ggml_tensor * ffn_gate_inp = nullptr;
struct ggml_tensor * ffn_gate_exps = nullptr;
struct ggml_tensor * ffn_down_exps = nullptr;
struct ggml_tensor * ffn_up_exps = nullptr;
struct ggml_tensor * ffn_gate_up_exps = nullptr;
struct ggml_tensor * ffn_gate_inp_b = nullptr;
struct ggml_tensor * ffn_gate_exps_b = nullptr;
struct ggml_tensor * ffn_down_exps_b = nullptr;
struct ggml_tensor * ffn_up_exps_b = nullptr;
struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
// ff shared expert (shexp)
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;

View File

@@ -1890,7 +1890,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "falcon-h1" ||
tokenizer_pre == "pixtral" ||
tokenizer_pre == "midm-2.0" ||
tokenizer_pre == "lfm2") {
tokenizer_pre == "lfm2" ||
tokenizer_pre == "jina-v5-nano") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;
add_bos = true;

View File

@@ -218,7 +218,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
LLM_FFN_SILU, hparams.expert_weights_norm,
hparams.expert_weights_scale, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
il,
nullptr,
model.layers[il].ffn_gate_up_exps);
cb(moe_out, "ffn_moe_out", il);
// FFN shared expert

97
src/models/eurobert.cpp Normal file
View File

@@ -0,0 +1,97 @@
#include "models.h"
llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
ggml_tensor * inp_pos = build_inp_pos();
inpL = build_inp_embd(model.tok_embd);
cb(inpL, "inp_embd", -1);
auto * inp_attn = build_attn_inp_no_cache();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * cur = inpL;
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
{
ggml_tensor * Qcur;
ggml_tensor * Kcur;
ggml_tensor * Vcur;
Qcur = build_lora_mm(model.layers[il].wq, cur);
Kcur = build_lora_mm(model.layers[il].wk, cur);
Vcur = build_lora_mm(model.layers[il].wv, cur);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
cb(cur, "kqv_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpL);
ggml_tensor * ffn_inp = cur;
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_embd", -1);
res->t_embd = cur;
ggml_build_forward_expand(gf, cur);
}

View File

@@ -116,6 +116,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Check layer type by checking which tensors exist
// KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
bool is_kda = (layer.ssm_a != nullptr);

View File

@@ -424,6 +424,10 @@ struct llm_build_neo_bert : public llm_graph_context {
llm_build_neo_bert(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_eurobert : public llm_graph_context {
llm_build_eurobert(const llama_model & model, const llm_graph_params & params);
};
template <bool iswa>
struct llm_build_olmo2 : public llm_graph_context {
llm_build_olmo2(const llama_model & model, const llm_graph_params & params);

View File

@@ -29,6 +29,8 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);

View File

@@ -29,6 +29,8 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -379,7 +380,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, model.layers[il].ffn_gate_up_exps);
cb(moe_out, "ffn_moe_out", il);
// Add shared experts if present - following Qwen3Next reference implementation

View File

@@ -21,6 +21,8 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@@ -354,7 +356,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -478,7 +479,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, model.layers[il].ffn_gate_up_exps);
cb(moe_out, "ffn_moe_out", il);
// Add shared experts if present - following Qwen3Next reference implementation

View File

@@ -48,6 +48,7 @@ enum handcrafted_file_type {
HANDCRAFTED_DATA_NOT_ENOUGH_DATA = 10 + offset_has_data,
HANDCRAFTED_DATA_BAD_ALIGN = 15 + offset_has_data,
HANDCRAFTED_DATA_INCONSISTENT_ALIGN = 20 + offset_has_data,
HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW = 30 + offset_has_data,
HANDCRAFTED_DATA_SUCCESS = 800 + offset_has_data,
HANDCRAFTED_DATA_CUSTOM_ALIGN = 810 + offset_has_data,
};
@@ -84,6 +85,7 @@ static std::string handcrafted_file_type_name(const enum handcrafted_file_type h
case HANDCRAFTED_DATA_NOT_ENOUGH_DATA: return "DATA_NOT_ENOUGH_DATA";
case HANDCRAFTED_DATA_BAD_ALIGN: return "DATA_BAD_ALIGN";
case HANDCRAFTED_DATA_INCONSISTENT_ALIGN: return "DATA_INCONSISTENT_ALIGN";
case HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW: return "DATA_MEM_SIZE_OVERFLOW";
case HANDCRAFTED_DATA_SUCCESS: return "DATA_SUCCESS";
case HANDCRAFTED_DATA_CUSTOM_ALIGN: return "DATA_CUSTOM_ALIGN";
}
@@ -196,6 +198,13 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
tensor_configs = get_tensor_configs(rng);
}
if (hft == HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW) {
tensor_configs.resize(2);
tensor_configs[0] = { GGML_TYPE_I8, { 0x7FFFFFFFFFFFFFC0, 1, 1, 1 } };
tensor_configs[1] = { GGML_TYPE_I8, { 0x7FFFFFFFFFFFFFC0, 1, 1, 1 } };
}
if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) {
const uint64_t n_tensors = -1;
helper_write(file, n_tensors);
@@ -397,7 +406,8 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
for (uint32_t i = 1; i < n_dims; ++i) {
ne *= shape[i];
}
offset += GGML_PAD(ggml_row_size(type, ne), alignment);
offset += GGML_PAD(ggml_row_size(type, ne), (uint64_t) alignment);
}
while (ftell(file) % alignment != 0) {
@@ -411,6 +421,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
if (hft == HANDCRAFTED_DATA_NOT_ENOUGH_DATA) {
nbytes -= 1;
}
if (hft == HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW) {
nbytes = 32;
}
for (uint64_t i = 0; i < nbytes; ++i) {
const uint8_t random_byte = i % 256;
helper_write(file, random_byte);
@@ -704,6 +717,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
HANDCRAFTED_DATA_NOT_ENOUGH_DATA,
HANDCRAFTED_DATA_BAD_ALIGN,
HANDCRAFTED_DATA_INCONSISTENT_ALIGN,
HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW,
HANDCRAFTED_DATA_SUCCESS,
HANDCRAFTED_DATA_CUSTOM_ALIGN,
};

View File

@@ -13,7 +13,12 @@ fi
name=$1
input=$2
make -j tests/test-tokenizer-0
# Build using CMake if binary doesn't exist
if [ ! -f ./build/bin/test-tokenizer-0 ]; then
printf "Building test-tokenizer-0 with CMake...\n"
cmake -B build -DLLAMA_BUILD_TESTS=ON
cmake --build build --target test-tokenizer-0 -j
fi
printf "Testing %s on %s ...\n" $name $input
@@ -23,7 +28,7 @@ printf "Tokenizing using (py) Python AutoTokenizer ...\n"
python3 ./tests/test-tokenizer-0.py ./models/tokenizers/$name --fname-tok $input > /tmp/test-tokenizer-0-$name-py.log 2>&1
printf "Tokenizing using (cpp) llama.cpp ...\n"
./tests/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1
./build/bin/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1
cat /tmp/test-tokenizer-0-$name-py.log | grep "tokenized in"
cat /tmp/test-tokenizer-0-$name-cpp.log | grep "tokenized in"

View File

@@ -912,7 +912,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_LAST) {
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
}
auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: tokenizing the input ..\n", __func__);

View File

@@ -248,7 +248,7 @@ int32_t mtmd_helper_decode_image_chunk(
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
int32_t i_batch = 0;
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
int32_t n_img_batches = (n_tokens + n_batch - 1) / n_batch;
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
if (mtmd_decode_use_mrope(ctx)) {

View File

@@ -1510,7 +1510,7 @@ version = 1
; If the same key is defined in a specific preset, it will override the value in this global section.
[*]
c = 8192
n-gpu-layer = 8
n-gpu-layers = 8
; If the key corresponds to an existing model on the server,
; this will be used as the default config for that model

View File

@@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
}
llama_pos server_tokens::pos_next() const {
llama_pos server_tokens::pos_next(int64_t n_tokens) const {
if (!has_mtmd) {
return tokens.size();
if (n_tokens < 0) {
return tokens.size();
}
return n_tokens;
}
llama_pos res = tokens.size();
if (n_tokens < 0) {
llama_pos res = tokens.size();
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
}
return res;
}
return res;
int64_t idx = 0;
llama_pos pos = 0;
GGML_ASSERT(n_tokens <= (int64_t)tokens.size());
while (idx < n_tokens) {
const auto media_it = map_idx_to_media.find(idx);
if (media_it != map_idx_to_media.end()) {
const auto & chunk = media_it->second;
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
pos += n_pos;
idx += n_tok;
} else {
pos++;
idx++;
}
}
return pos;
}
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
if (!has_mtmd) {
return std::min((size_t)(max_pos + 1), tokens.size());
}
size_t idx = 0;
llama_pos pos = 0;
while (idx < tokens.size()) {
const auto media_it = map_idx_to_media.find(idx);
if (media_it != map_idx_to_media.end()) {
const auto & chunk = media_it->second;
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
pos += n_pos;
idx += n_tok;
} else {
pos++;
idx++;
}
if (pos > max_pos) {
break;
}
}
return idx;
}
std::string server_tokens::str() const {

View File

@@ -167,7 +167,12 @@ public:
// for debugging
std::string str() const;
llama_pos pos_next() const;
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
llama_pos pos_next(int64_t n_tokens = -1) const;
// number of tokens with position <= max_pos
size_t size_up_to_pos(llama_pos max_pos) const;
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
void push_back(llama_token tok);

View File

@@ -995,9 +995,6 @@ private:
// don't update the cache if the slot's context is empty
update_cache = update_cache && tokens.size() > 0;
// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);
if (update_cache) {
SRV_WRN("%s", "updating prompt cache\n");
@@ -1442,7 +1439,7 @@ private:
res->id = slot.task->id;
res->id_slot = slot.id;
res->index = slot.task->index;
res->index = slot.task->index;
// keep copy of last generated text for debugging purposes
if (slots_debug) {
@@ -2282,15 +2279,15 @@ private:
n_past = 0;
}
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, n_past - n_swa);
const auto pos_min_thold = std::max(0, pos_next - n_swa);
// note: disallow with mtmd contexts for now
// https://github.com/ggml-org/llama.cpp/issues/17043
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
@@ -2341,9 +2338,6 @@ private:
}
if (pos_min > pos_min_thold) {
// TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
@@ -2364,18 +2358,20 @@ private:
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
pos_next = 0;
n_past = 0;
}
}
@@ -2386,7 +2382,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@@ -2402,7 +2398,7 @@ private:
SLT_WRN(slot, "n_past was set to %d\n", n_past);
}
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
@@ -2520,10 +2516,6 @@ private:
}
}
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@@ -2536,8 +2528,6 @@ private:
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
slot.init_sampler();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -2549,13 +2539,15 @@ private:
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
@@ -2563,16 +2555,21 @@ private:
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
} else {
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
}

View File

@@ -339,6 +339,17 @@ static std::map<std::string, std::string> get_headers(const httplib::Request & r
return headers;
}
static std::string build_query_string(const httplib::Request & req) {
std::string qs;
for (const auto & [key, value] : req.params) {
if (!qs.empty()) {
qs += '&';
}
qs += httplib::encode_query_component(key) + "=" + httplib::encode_query_component(value);
}
return qs;
}
// using unique_ptr for request to allow safe capturing in lambdas
using server_http_req_ptr = std::unique_ptr<server_http_req>;
@@ -382,6 +393,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
req.body,
req.is_connection_closed
});
@@ -396,6 +408,7 @@ void server_http_context::post(const std::string & path, const server_http_conte
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
req.body,
req.is_connection_closed
});

View File

@@ -36,7 +36,8 @@ using server_http_res_ptr = std::unique_ptr<server_http_res>;
struct server_http_req {
std::map<std::string, std::string> params; // path_params + query_params
std::map<std::string, std::string> headers; // reserved for future use
std::string path; // reserved for future use
std::string path;
std::string query_string; // query parameters string (e.g. "action=save")
std::string body;
const std::function<bool()> & should_stop;

View File

@@ -291,7 +291,9 @@ void server_models::load_models() {
for (const auto & [name, inst] : mapping) {
std::string val;
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
models_to_load.push_back(name);
if (common_arg_utils::is_truthy(val)) {
models_to_load.push_back(name);
}
}
}
if ((int)models_to_load.size() > base_params.models_max) {
@@ -697,11 +699,15 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
mapping[name].meta.last_used = ggml_time_ms();
}
SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
std::string proxy_path = req.path;
if (!req.query_string.empty()) {
proxy_path += '?' + req.query_string;
}
auto proxy = std::make_unique<server_http_proxy>(
method,
CHILD_ADDR,
meta->port,
req.path,
proxy_path,
req.headers,
req.body,
req.should_stop,

View File

@@ -204,7 +204,8 @@ task_params server_task::params_from_json_cmpl(
params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt);
params.return_tokens = json_value(data, "return_tokens", false);
params.return_progress = json_value(data, "return_progress", false);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
auto max_tokens = json_value(data, "max_tokens", defaults.n_predict);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
@@ -1899,10 +1900,9 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto & cur = states.emplace_back();
cur = {
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};

View File

@@ -557,6 +557,8 @@ struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
int64_t n_tokens;
std::vector<uint8_t> data;
size_t size() const {