Compare commits

...

75 Commits
b6823 ... b6898

Author SHA1 Message Date
Jeff Bolz
d2d931f173 vulkan: disable spirv-opt for rope shaders (#16872) 2025-10-31 08:34:47 +01:00
Masato Nakasaka
2976b0374d vulkan: Fix crash when FP16 mul_mat accumulation is not supported (#16796)
* Experimenting crash fix

* added assert for aborting and fixed comment

* changed to check if a pipeline is empty or not

* Moved function in class definition

* replaced with is_empty

* Modified is_empty to check only unaligned pipelines
2025-10-31 08:18:59 +01:00
Ruben Ortlam
d2a2673dd1 vulkan: fix shmem overrun in mmq id shader (#16873)
* vulkan: fix shmem overrun in mmq id shader

* metal : fix mul_mm_id

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-31 08:14:49 +01:00
l3utterfly
13002a0896 ggml-hexagon: respect input size when getting/setting tensor data (#16836)
* respect input size when getting/setting tensor data

allows partial repacking/copying when get tensor size is smaller than the actual tensor

* Removed duplicate repack_mxfp4_mxfp4x4x2 function
2025-10-30 21:46:31 -07:00
Sigbjørn Skjæret
6eb208d17e ci : enable free-disk-space on cuda docker build (#16877) 2025-10-31 00:34:27 +01:00
lhez
9984cbb61d opencl: fix boundary handling for mul_mm (#16875) 2025-10-30 16:00:20 -07:00
RodriMora
ce18efeaf1 convert : update transformers requirements (#16866)
* Update requirements-convert_legacy_llama.txt

Updated requirements to support Qwen3-VL in transformers 4.57.1 version

* Update requirements/requirements-convert_legacy_llama.txt

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-30 23:15:03 +01:00
chansikpark
16724b5b68 server : bump request URI max length to 32768 (#16862) 2025-10-30 20:22:23 +02:00
Georgi Gerganov
b52edd2558 server : remove n_past (#16818)
* server : remove n_past

* server : replace slot.n_prompt_tokens() with slot.task->n_tokens()

* server : fixes + clean-up

* cont : fix context shift

* server : add server_tokens::pos_next()

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : fix pos_next() usage

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
2025-10-30 18:42:57 +02:00
Max Krasnyansky
517b7170e1 cpu: introduce chunking for repack matmuls and enable matmul-id chunking on ARM64 (#16833)
Very similar implementation to the flash-attention chunking, with similar benefits.
2025-10-30 09:06:13 -07:00
Shagun Bera
835e918d84 common: fix typo in cli help text (#16864) 2025-10-30 17:47:31 +02:00
JJJYmmm
d261223d24 model: add support for qwen3vl series (#16780)
* support qwen3vl series.

Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>

* bugfix: fix the arch check for qwen3vl-moe.

* use build_ffn

* optimize deepstack structure

* optimize deepstack feature saving

* Revert "optimize deepstack feature saving" for temporal fix

This reverts commit f321b9fdf1.

* code clean

* use fused qkv in clip

* clean up / rm is_deepstack_layers for simplification

* add test model

* move test model to "big" section

* fix imrope check

* remove trailing whitespace

* fix rope fail

* metal : add imrope support

* add imrope support for sycl

* vulkan: add imrope w/o check

* fix vulkan

* webgpu: add imrope w/o check

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix tensor mapping

---------

Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-30 16:19:14 +01:00
Max Krasnyansky
dcca0d3ab8 cpu: introduce chunking for flash attention (#16829)
Factor out the core FA loop into flash_atten_f16_one_chunk and add an outter loop
on top that handles the chunks.
2025-10-30 14:26:05 +02:00
Tianyue-Zhao
bacddc049a model: Add support for CogVLM model (#15002)
* Added GGUF mappings for CogVLM model

* Add tensor mapping for CogVLM visual encoder

* Add CogVLM to conversion script, no vision part yet

* Added CogVLM vision model to conversion script

* Add graph for CogVLM CLIP model

* Add graph for CogVLM

* Fixes for CogVLM. Now compiles.

* Model now runs

* Fixes for cogvlm graph

* Account for graph context change after rebase

* Changes for whitespace

* Changes in convert script according to comments

* Switch CogVLM LLM graph to merged QKV tensor

* Use rope_type variable instead of direct definition

* Change CogVLM CLIP encoder to use SWIGLU

* Switch CogVLM CLIP to use merged QKV

* Apply rebase edits and remove ggml_cont call that is now unnecessary

* clean up

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-10-30 12:18:50 +01:00
Sigbjørn Skjæret
229bf68628 cuda : fix argsort with 64k+ rows (#16849) 2025-10-30 08:56:28 +01:00
Jan Boon
d7395115ba llama : use std::abs instead of abs (#16853) 2025-10-30 08:30:58 +02:00
Jeff Bolz
052df28b0e vulkan: Handle argsort with a large number of rows (#16851) 2025-10-30 07:27:41 +01:00
Oliver Simons
8b11deea46 Hide latency of bias and gate-loading (#16847)
This is realised by loading them into registers before computation of
the dot-product, effectively batching them together with said
dot-product. As a lot of threads are alive here, the warp scheduler has
enough threads available to effectively hide the cost of additionally
loading those two floats.
2025-10-30 11:34:15 +08:00
Jeff Bolz
b9ce940177 vulkan: Fuse rope+set_rows (#16769)
This pattern appears in a lot of models, the rope operation is applied right
before storing into the KV cache (usually on the K tensor).

Add a path to some of the rope shaders that computes the destination address
based on the set_rows tensor. Compile variants of the shader with D_TYPE of
f16 (the usual KV cache type).

Add a src3 operand to ggml_vk_op_f32 - sometimes rope uses three srcs and needs
the fourth for the row indices.

Add fused_ops_write_mask to indicate which intermediate tensors need to write
their results to memory. Skipping writing the roped K value helps to allow more
nodes to run concurrently.

Add logic to ggml_vk_graph_optimize to make ROPE+VIEW+SET_ROWS consecutive. It
rarely starts out that way in the graph.

Add new backend tests.
2025-10-29 15:13:10 -05:00
Xuan-Son Nguyen
3464bdac37 llama: fix ASAN error with M-RoPE (#16848) 2025-10-29 20:11:39 +01:00
Xuan-Son Nguyen
e3af5563bd llama: store mrope data in KV cell (#16825)
* llama: store mrope data in KV cell

* correct x,y ordering

* address review comments

* add consistency checks

* Update src/llama-kv-cache.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add TODO

* fix asan error

* kv-cells : improve ext handling

* cont : fix headers

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-29 18:09:18 +01:00
Jeff Bolz
10fcc41290 vulkan: Update topk_moe fusion to handle gpt's late softmax (#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax

Based on #16649.

* Add ggml_check_edges

* Add sync logging to show fusion effects

* handle clamp added in #16655

* Update ggml/src/ggml-impl.h

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-10-29 14:44:29 +01:00
Ruben Ortlam
bcf5bda6f5 Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
* vulkan: add mmq q2_k integer dot support

* Refactor mmq caching

* Reduce mmq register use

* Load 4 quant blocks into shared memory in one step

* Pack q2_k blocks into caches of 32

* Use 32-bit accumulators for integer dot matmul

* Add q4_k mmq

* Add q3_k mmq

* Add q5_k mmq

* Add q6_k mmq

* Add mxfp4 mmq, enable MMQ MUL_MAT_ID

* Fix mmv dm loads
2025-10-29 14:39:03 +01:00
Max Krasnyansky
3eb2be1ca5 Hexagon Op queue & dispatch optimizations (#16820)
* hexagon: remove dspqueue callbacks and do all read processing inplace

* hexagon: there is no need to ref/deref the buffers at this point

We're not going to release the buffers without flushing the session queue.
So there is no need to inc/dec the refcounts for every request.
We also don't need to include those bufs in the response.

* hexagon: bump the thread count in the adb wrapper scripts

We can use more CPU cores now that the dedicated dspqueue polling threads are not used (ie no contention).
Also enable more agressive polling for now since we still map Flash Attention (and a few other kernels) to
the CPU and those dspqueue threads were keeping the CPU cores are higher clock freqs.

* hexagon: add lhez as the second code owner
2025-10-29 06:29:12 -07:00
Aman Gupta
e41bcce8f0 CUDA: use fastdiv in set-rows (#16834)
* CUDA: use fastdiv in set-rows

* add assert about value fitting in u32
2025-10-29 21:11:53 +08:00
Sigbjørn Skjæret
144a4ce824 vendor : sync minja (#16500)
* sync minja.hpp

Adds Call/EndCall support, used in MiniCPM3 and MiniCPM4-MCP.

* remove spurious semicolon

* sync from ochafik/minja
2025-10-29 14:09:50 +01:00
Jeff Bolz
f549b0007d vulkan: Call ggml_vk_buffer_write_2d from ggml_vk_buffer_copy (#16793)
This lets the copy to the destination device use the host-visible
vidmem optimization.
2025-10-29 09:53:04 +01:00
Aman Gupta
9a3ea685b9 CUDA: Fix bug in topk-moe for gpt-oss (#16821)
* CUDA: Fix bug in topk-moe for gpt-oss

When using ggml_can_fuse_subgraph, the output nodes which are passed are wrong. This causes `test-backend-ops` to still fuse ndoes (because the nodes are not used elsewhere in the graph),
but it actually doesn't fuse in the actual gpt-oss

* fix for qwen3 too

* change ifndef to ifdef
2025-10-29 15:55:06 +08:00
YaelLogic
338074c383 sycl: add RMS_NORM_BACK operation support (#16808)
* sycl: add RMS_NORM_BACK operation support

* sycl: rms_norm_back: add dual reduction paths (FP64 and FP32) and savepoint before further changes

* sycl: add RMS_NORM_BACK support

Implement RMS_NORM_BACK for the SYCL backend using FP32 compensated parallel reduction. Minimal docs updates (ops.md / SYCL.csv).

* revert: restore .gitignore and tools/run/CMakeLists.txt to upstream

* revert: restore tests/CMakeLists.txt to upstream

* sycl: optimize rms_norm_back

* fix: restore SYCL.csv to correct state with RMS_NORM_BACK support

* Update ggml/src/ggml-sycl/norm.cpp

Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>

* fix: remove trailing whitespace and add missing newline (EditorConfig)

---------

Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2025-10-29 14:14:39 +08:00
YaelGitAccount
851553ea6b cuda: add SET operation support (#16804)
* feat(cuda): add GGML_OP_SET support

Implement CUDA kernel for SET operation with f32 support.

All tests passing (14598/14598).

* cuda(set): add I32 support; keep F32

* refactor(cuda): use ggml_cuda_cpy to unify SET operator logic and remove code duplication

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update ggml/src/ggml-cuda/set.cu

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-28 20:10:28 +01:00
Georgi Gerganov
85a7d8677b memory : remove KV cache size padding (#16812)
* memory : remove KV cache size padding

* cont : restore padding for n_kv tensor shape

* server : use slot context size instead of training context size

* server : simplify context limit logic
2025-10-28 20:19:44 +02:00
Georgi Gerganov
a8ca18b4b8 llama-bench : clarify benchmarked parts of the computation (#16823) 2025-10-28 19:41:43 +02:00
l3utterfly
8284efc35c initialise buffer.device in ggml_hexagon_session (#16816) 2025-10-28 08:16:20 -07:00
Sam Malayek
1c1409e131 embedding: add raw option for --embd-output-format (#16541)
* Add --embd-output-format raw for plain numeric embedding output

This new option outputs embeddings as raw space-separated floats, without JSON or 'embedding N:' prefixes. Useful for downstream vector pipelines and scripting.

* Move raw output handling into format handling section

* Move raw output handling into else-if block with other format handlers

* Use LOG instead of printf for raw embedding output

* docs: document 'raw' embedding output format in arg.cpp and README
2025-10-28 12:51:41 +02:00
Johannes Gäßler
7a0e900e36 llama: consistent ctx <-> buf order for KV cache (#16746) 2025-10-28 11:23:54 +01:00
Aldehir Rojas
280d97be96 grammar : support array references in json schema (#16792)
* grammar : support array references in json schema

* Update json-schema-to-grammar.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* grammar : improve regex when naming ref derived rules

* grammar : replace non-conformant definitions array with anyOf test case

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-28 09:37:52 +01:00
Chenguang Li
3479efd112 CANN: Improve device ID handling and aclnnArange checks (#16752)
* cann: improve device ID handling and aclnnArange checks

- Stop relying on CANN's internal device ID retrieval; use a global variable instead.
- Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.

* cann: use thread local var
2025-10-28 10:54:53 +08:00
Aman Gupta
463bbf20bf CUDA: add unused vars to mmvf and mmvq (#16807) 2025-10-28 10:31:21 +08:00
tamarPal
ad8d36beff sycl: add SSM_CONV operation support (#16800)
* feat: Add SYCL backend support for SSM_CONV operator

* Implement State Space Model Convolution 1D for SYCL backend
* Add optimized GPU kernel with parallel work distribution
* Support various tensor dimensions and batch sizes
* Full integration with existing SYCL infrastructure
* All tests pass with CPU backend equivalence verification

* feat: Implement SYCL backend support for SSM_CONV operation

- Add ggml-sycl/ssm_conv.cpp and ssm_conv.hpp
- Implement SYCL kernel for state space model convolution
- Ensure numerical correctness matches CPU implementation exactly
- Add proper type checking for F32 tensors in backend support
- All test-backend-ops SSM_CONV tests pass (14490/14490)

* Perfect SSM_CONV SYCL implementation - 100% CPU parity

 Flawless numerical accuracy - matches CPU bit-for-bit
 Optimal SYCL kernel design - efficient parallel execution
 Complete tensor layout compatibility - handles all strides correctly
 Robust error handling - comprehensive assertions and validation
 All official tests pass - 14,490/14,490 backend operations verified
 Production-ready code - clean, documented, maintainable

Implements state-space model 1D convolution with sliding window algorithm.
Eliminates blocking queue.wait() for better async performance.

* Clean SSM_CONV code - remove all comments for production

Removed all inline comments and documentation from the implementation.
Clean, minimal code ready for production merge.

* fix: Final formatting corrections for CI compliance

- Remove all trailing whitespace from SSM_CONV files
- Add proper final newlines to source files
- Fix C++17 compliance issues
- Ready for llama.cpp CI validation

* sycl: fix trailing whitespace and minor safety casts in ssm_conv

* fix: Clean up duplicated content in ssm_conv.hpp header file

---------

Co-authored-by: tamarPal <tamarPal@example.com>
2025-10-28 09:50:33 +08:00
Yuri Khrustalev
c053e18a66 chat: Add LFM2 tool handling (#16763)
* Add LFM2 tool handling

* fmt

* Apply suggestion from @ykhrustalev
2025-10-27 23:54:01 +01:00
Xuan-Son Nguyen
e1ab084803 mtmd : fix idefics3 preprocessing (#16806)
* mtmd : fix idefics3 preprocessing

* disable granite test

* fix test for granite
2025-10-27 23:12:16 +01:00
Diego Devesa
5a4ff43e7d llama : disable pipeline parallelism if compute buffer allocation fails (#16748) 2025-10-27 21:51:28 +01:00
Acly
10640e31aa ggml : fix interpolate with align-corners and ne=1 (#16700)
* ggml : fix interpolate with align-corners and ne=1

* avoid division by zero if one of the spatial dimensions is 1
* cpu, cuda, opencl returned correct result anyway due to clamp
* vulkan didn't clamp for align-corners so results were broken

* fix clang warning
2025-10-27 21:50:22 +01:00
Johannes Gäßler
80d28f104c HIP: fix AMDGPU_TARGETS, update documentation (#16803) 2025-10-27 21:39:49 +01:00
Xuan-Son Nguyen
c55d53acec model : add LightOnOCR-1B model (#16764)
* model : add LightOnOCR-1B model

* add test
2025-10-27 16:02:58 +01:00
Johannes Gäßler
945501f5ea llama: fix leaked buffers for mmap + split files (#16765) 2025-10-27 09:17:31 +01:00
Aman Gupta
75cbdd3fce test-backend-ops: print failed tests at the end (#16785) 2025-10-27 09:25:10 +08:00
tamarPal
2b9bd9bf4e sycl: add ROLL operation support (#16665)
* sycl: add ROLL operation support

- Implement ggml_sycl_roll function for F32 tensors
- Add multi-axis roll operation with SYCL kernel
- Support all 4 tensor dimensions with proper shift normalization
- Add roll.cpp and roll.hpp to SYCL backend
- Update backend dispatch and supports_op for GGML_OP_ROLL
- Tests: 17662/17662 pass with identical CPU reference results

* fix: remove trailing whitespace from roll.cpp

- Fix EditorConfig violations in ggml/src/ggml-sycl/roll.cpp
- Remove trailing spaces from lines 6, 11, 28, 47, 58, 60

* ci: retrigger

* sycl: remove wait() calls from ROLL operation

* fix: editorconfig — LF endings + final newline for roll.hpp

---------

Co-authored-by: tamarPal <tamarPal@example.com>
2025-10-27 09:20:24 +08:00
shani-f
59fc1ec8e8 sycl: add REPEAT_BACK operation support (#16734)
* SYCL repeat_back v1 — add core op + switch case

* Implement repeat_back SYCL operation and minor fixes

* Update ggml/src/ggml-sycl/repeat_back.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update ggml/src/ggml-sycl/repeat_back.hpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-27 09:19:50 +08:00
Aman Gupta
75d33b9302 CUDA: support for weight clamp in top-k norm (#16702) 2025-10-27 09:06:16 +08:00
Acly
3470a5c891 ggml-alloc : make gallocr prefer chunks that allow memory reuse (#16788) 2025-10-26 23:19:03 +01:00
Sigbjørn Skjæret
bd562fe4f7 cuda : use fast copy when src and dst are of different type and contiguous (#16789)
* use fast copy when src and dst are contiguous and same shape

* use int64_t ne and ignore shape
2025-10-26 21:31:41 +01:00
leejet
bbac6a26b2 ggml: fix cuda kernel launch configuration for k_compute_batched_ptrs to support large batch (#16744)
* fix k_compute_batched_ptrs

* add backend ops test

* Update ggml/src/ggml-cuda/ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* reduce the batch size

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-10-26 19:13:31 +01:00
Sigbjørn Skjæret
73a48c9790 convert : enable expert group selection for all models with it (#16691) 2025-10-26 17:21:23 +01:00
Sigbjørn Skjæret
f696428ce8 graph : add clamping to ffn_moe_weights_sum to avoid div-by-zero (#16655)
* add missing norm topk bias

* use clamping instead, update number and add comment
2025-10-26 17:20:32 +01:00
Sigbjørn Skjæret
7cce4f8158 model : set res->t_embd in SmallThinker models (#16782) 2025-10-26 16:08:52 +01:00
amirai21
8d8862829c docs : add Jamba to Text-only models list (#16778) 2025-10-26 13:01:20 +01:00
Aman Gupta
f77c13b91f CUDA: General GEMV fusion (#16715) 2025-10-26 19:28:04 +08:00
Gilad S.
3cfa9c3f12 vulkan: deduplicate Microsoft Direct3D12 devices (#16689)
* fix: deduplicate and deprioritize Microsoft Direct3D12 vulkan devices from the `vulkan-dozen` driver

* style: indent

* fix: decrease priority

* fix: switch to `||`
2025-10-26 05:37:38 +01:00
Galunid
5d195f17bc convert : handle mmproj filename/path properly (#16760)
* convert: handle mmproj model output filename properly

* remove redundant commits

* Add model_type to gguf utility

* Use mmproj- prefix instead of suffix

* Apply CISC suggestion

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-25 20:41:36 +02:00
Shunta Saito
226f295f4d model : set res->t_embd in PLaMo2 models (#16766) 2025-10-25 12:26:27 +02:00
Giuseppe Scrivano
f90b4a8efe vulkan: delete dead code (#16732)
ggml_vk_create_buffer_temp is not used anywhere, and it is the only
caller for ggml_vk_pool_malloc.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-10-25 10:59:54 +02:00
Jeff Bolz
8423d01931 vulkan: Optimize SSM_SCAN (#16645) 2025-10-25 07:04:12 +02:00
compilade
5cca2542ac convert : avoid dequantizing mxfp4 for GPT-OSS (#16756) 2025-10-24 20:52:00 -04:00
leejet
55945d2ef5 ggml: fix CUDA grid launch condition for large block_nums.y in binbcast (#16742)
* Fix CUDA grid launch condition for large block_nums.y

* add backend ops test

* reduce test  repetitions
2025-10-24 21:39:37 +02:00
Aman Gupta
0bcb40b48c CUDA: use CUB for arbitary size argsort (#16754) 2025-10-24 20:46:19 +08:00
Florian Badie
69e9ff0103 webui: support q URL parameter (#16728)
* webui: support q URL parameter

Fixes #16722
I’ve checked that it works with Firefox’s AI tools

* webui: apply suggestions from code review

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* chore: update webui static build

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-10-24 14:10:29 +02:00
Daniel Bevenius
5a91109a5d model-conversion : add trust_remote_code for orig model run [no ci] (#16751)
This commit add the trust_remote_code=True argument when loading models
using AutoConfig, AutoTokenizer, and AutoModelForCausalLM for the run
original model script.

The motivation for this is that some models require custom code to be
loaded properly, and setting trust_remote_code=True avoids a prompt
asking for user confirmation:
```console
(venv) $ make causal-run-original-model
The repository /path/to/model contains custom code which must be
executed to correctly load the model. You can inspect the repository
content at /path/to/model.

Do you wish to run the custom code? [y/N] N
```

Having this as the default seems like a safe choice as we have to clone
or download the models we convert and would be expecting to run any
custom code they have.
2025-10-24 12:02:02 +02:00
compilade
f8f071fadd convert : handle pre-quantized models (#14810)
* convert : begin handling pre-quantized models

* convert : fix conversion from FP8 for Deepseek-V3.1-Base
2025-10-23 16:31:41 -04:00
Johannes Gäßler
0bf47a1dbb server: add memory breakdown print (#16740) 2025-10-23 21:30:17 +02:00
Julien Denize
dd62dcfab9 convert : Make mistral-common dependency optional (#16738)
* Make mistral-common dependency optional

* Fix typing
2025-10-23 15:54:46 +02:00
Xuan-Son Nguyen
d0660f237a mtmd-cli : allow using --jinja (#16718)
* mtmd-cli : allow using --jinja

* support -sys

* implement chat_history

* fix clear memory

* rm -sys support, added TODO
2025-10-23 15:00:49 +02:00
Prajwal B Mehendarkar
fe6a9882ac Manually link -lbsd to resolve flock symbol on AIX (#16610) 2025-10-23 19:37:31 +08:00
Aman Gupta
061f0eff02 ggml-cuda: use passed ops instead of hardcoded ops (#16712) 2025-10-23 19:14:06 +08:00
matteo
8cf6b42d46 server : send partial stop string when <EOG> is reached (#15007) 2025-10-23 12:32:24 +03:00
137 changed files with 7350 additions and 2142 deletions

View File

@@ -40,7 +40,7 @@ jobs:
# https://github.com/ggml-org/llama.cpp/issues/11888
#- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
- { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
- { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }

View File

@@ -65,7 +65,7 @@
/ggml/src/ggml-impl.h @ggerganov @slaren
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov

View File

@@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
- [x] [Jamba](https://huggingface.co/ai21labs)
- [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)

View File

@@ -3203,7 +3203,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
{"--parse-special"},
string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
[](common_params & params) {
params.parse_special = true;
}
@@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--embd-output-format"}, "FORMAT",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
[](common_params & params, const std::string & value) {
params.embd_out = value;
}
@@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--reasoning-format"}, "FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"

View File

@@ -9,8 +9,11 @@
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>
#include <algorithm>
#include <cstdio>
#include <cctype>
#include <exception>
#include <functional>
#include <iostream>
#include <optional>
#include <stdexcept>
@@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
return data;
}
// Case-insensitive find
static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
auto it = std::search(
haystack.begin() + pos, haystack.end(),
needle.begin(), needle.end(),
[](char a, char b) { return std::tolower(a) == std::tolower(b); }
);
return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
}
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
const auto is_json_schema_provided = !inputs.json_schema.is_null();
const auto is_grammar_provided = !inputs.grammar.empty();
const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
// the logic requires potentially modifying the messages
auto tweaked_messages = inputs.messages;
auto replace_json_schema_marker = [](json & messages) -> bool {
static std::string marker1 = "force json schema.\n";
static std::string marker2 = "force json schema.";
if (messages.empty() || messages.at(0).at("role") != "system") {
return false;
}
std::string content = messages.at(0).at("content");
for (const auto & marker : {marker1, marker2}) {
const auto pos = ifind_string(content, marker);
if (pos != std::string::npos) {
content.replace(pos, marker.length(), "");
// inject modified content back into the messages
messages.at(0).at("content") = content;
return true;
}
}
return false;
};
// Lfm2 model does not natively work with json, but can generally understand the tools structure
//
// Example of the pytorch dialog structure:
// <|startoftext|><|im_start|>system
// List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
// <|im_start|>user
// What is the current status of candidate ID 12345?<|im_end|>
// <|im_start|>assistant
// <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
// <|im_start|>tool
// <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
// <|im_start|>assistant
// The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
//
// For the llama server compatibility with json tools semantic,
// the client can add "Follow json schema." line into the system message prompt to force the json output.
//
if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
// server/utils.hpp prohibits that branch for the custom grammar anyways
throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
} else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
LOG_INF("%s: Using tools to build a grammar\n", __func__);
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function.at("name")},
}},
{"arguments", function.at("parameters")},
}},
{"required", json::array({"name", "arguments", "id"})},
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
});
// model has no concept of tool selection mode choice,
// if the system prompt rendered correctly it will produce a tool call
// the grammar goes inside the tool call body
data.grammar_lazy = true;
data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
} else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
// output those tokens
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
} else if (is_json_schema_provided) {
LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
data.grammar = json_schema_to_grammar(inputs.json_schema);
} else if (is_grammar_provided) {
LOG_INF("%s: Using provided grammar\n", __func__);
data.grammar = inputs.grammar;
} else {
LOG_INF("%s: Using content relying on the template\n", __func__);
}
data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
return data;
}
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();
// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}
// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}
if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}
std::string function_name = tool_call.at("name");
std::string arguments = "{}";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}
// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}
// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_apertus(tmpl, params);
}
// LFM2 (w/ tools)
if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
src.find("]<|tool_list_end|>") != std::string::npos) {
return common_chat_params_init_lfm2(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}

View File

@@ -116,6 +116,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_APERTUS,
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};

View File

@@ -601,7 +601,10 @@ private:
}
std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
auto it = ref.find('#');
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
@@ -774,11 +777,24 @@ public:
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
if (target.is_object() && target.contains(sel)) {
target = target[sel];
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}
if (sel_index >= target.size()) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel_index];
} else {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}

View File

@@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
from gguf.vocab import MistralTokenizerType, MistralVocab
from mistral_common.tokens.tokenizers.base import TokenizerVersion
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
try:
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
SentencePieceTokenizer,
)
_mistral_common_installed = True
_mistral_import_error_msg = ""
except ImportError:
_MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
_mistral_common_installed = False
TokenizerVersion = None
Tekkenizer = None
SentencePieceTokenizer = None
_mistral_import_error_msg = (
"Mistral format requires `mistral-common` to be installed. Please run "
"`pip install mistral-common[image,audio]` to install it."
)
logger = logging.getLogger("hf-to-gguf")
@@ -73,10 +90,8 @@ class ModelBase:
use_temp_file: bool
lazy: bool
dry_run: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
tensor_names: set[str] | None
model_tensors: dict[str, Callable[[], Tensor]]
gguf_writer: gguf.GGUFWriter
model_name: str | None
metadata_override: Path | None
@@ -107,6 +122,9 @@ class ModelBase:
type(self) is MmprojModel:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
if self.is_mistral_format and not _mistral_common_installed:
raise ImportError(_mistral_import_error_msg)
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
@@ -117,25 +135,8 @@ class ModelBase:
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
if remote_hf_model_id is not None:
self.is_safetensors = True
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
for name, remote_tensor in remote_tensors.items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
self.get_tensors = get_remote_tensors
else:
prefix = "model" if not self.is_mistral_format else "consolidated"
self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.tensor_names = None
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
@@ -151,6 +152,8 @@ class ModelBase:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -172,67 +175,215 @@ class ModelBase:
return None
raise KeyError(f"could not find any of: {keys}")
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_names_from_parts: set[str] = set()
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
tensors: dict[str, Callable[[], Tensor]] = {}
if remote_hf_model_id is not None:
is_safetensors = True
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
for name, remote_tensor in remote_tensors.items():
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
return tensors
prefix = "model" if not self.is_mistral_format else "consolidated"
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
is_safetensors: bool = len(part_names) > 0
if not is_safetensors:
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
tensor_names_from_index: set[str] = set()
if not self.is_mistral_format:
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
index_file = self.dir_model / index_name
if index_file.is_file():
self.tensor_names = set()
logger.info(f"gguf: loading model weight map from '{index_name}'")
with open(index_file, "r", encoding="utf-8") as f:
index: dict[str, Any] = json.load(f)
weight_map = index.get("weight_map")
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
self.tensor_names.update(weight_map.keys())
tensor_names_from_index.update(weight_map.keys())
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}
for part_name in self.part_names:
logger.info(f"gguf: loading model part '{part_name}'")
for part_name in part_names:
logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any]
if self.is_safetensors:
if is_safetensors:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
with ctx as model_part:
tensor_names_from_parts.update(model_part.keys())
assert model_part is not None
for name in model_part.keys():
if self.is_safetensors:
if is_safetensors:
if self.lazy:
data = model_part.get_slice(name)
data = LazyTorchTensor.from_safetensors_slice(data)
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
else:
data = model_part.get_tensor(name)
data_gen = lambda data=data: data # noqa: E731
else:
data = model_part[name]
if self.lazy:
data = LazyTorchTensor.from_eager(data)
yield name, data
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
else:
data_gen = lambda data=data: data # noqa: E731
tensors[name] = data_gen
# verify tensor name presence and identify potentially missing files
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
f"Missing tensors: {missing}")
if len(tensor_names_from_index) > 0:
tensor_names_from_parts = set(tensors.keys())
if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
f"Missing tensors: {missing}")
else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n"
f"Extra tensors: {extra}")
return tensors
def dequant_model(self):
tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
quant_method = quant_config.get("quant_method")
def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
weight = weight.view(torch.uint8)
orig_shape = weight.shape
shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
data = data & 3
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
# The scale is inverted
return data / scale.float()
def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
scale = scale.float()
if (weight_block_size := quant_config.get("weight_block_size")):
# TODO: make sure it's a list of integers
for i, size in enumerate(weight_block_size):
scale = scale.repeat_interleave(size, i)
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
scale = scale[tuple(slice(0, size) for size in weight.shape)]
return weight.float() * scale
# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
bits = quant_config["bits"]
assert bits in (2, 3, 4, 8)
assert qweight.dtype == qzeros.dtype
maxq = (2 ** bits) - 1
weight = None
zeros = None
pack_dtype_bits = qweight.dtype.itemsize * 8
if bits in [2, 4, 8]:
pack_factor = pack_dtype_bits // bits
wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
if self.lazy:
wf = LazyTorchTensor.from_eager(wf)
zeros = torch.bitwise_right_shift(
qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
wf.unsqueeze(0)
).to(torch.int16 if bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
weight = torch.bitwise_and(
torch.bitwise_right_shift(
qweight.unsqueeze(1).expand(-1, pack_factor, -1),
wf.unsqueeze(-1)
).to(torch.int16 if bits == 8 else torch.int8),
maxq
)
elif bits == 3:
raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
assert weight is not None
assert zeros is not None
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
# gptq_v2 doesn't need to offset zeros
if quant_config.get("checkpoint_format", "gptq") == "gptq":
zeros += 1
return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
if quant_method == "bitnet":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
tensors_to_remove.append(name)
elif quant_method == "fp8":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale_inv"):
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
tensors_to_remove.append(name)
elif quant_method == "gptq":
for name in self.model_tensors.keys():
if name.endswith(".qweight"):
base_name = name.removesuffix(".qweight")
g_idx = self.model_tensors[base_name + ".g_idx"]
qweight = self.model_tensors[base_name + ".qweight"]
qzeros = self.model_tensors[base_name + ".qzeros"]
scales = self.model_tensors[base_name + ".scales"]
new_tensors[base_name + ".weight"] = (
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
g(), w(), z(), s()
)
)
tensors_to_remove += [
base_name + n
for n in (
".g_idx",
".qzeros",
".qweight",
".scales",
)
]
else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n"
f"Extra tensors: {extra}")
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
for name in tensors_to_remove:
if name in self.model_tensors:
del self.model_tensors[name]
for name, value in new_tensors.items():
self.model_tensors[name] = value
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, gen in self.model_tensors.items():
yield name, gen()
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -591,6 +742,12 @@ class TextModel(ModelBase):
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
logger.info(f"gguf: experts used count = {n_experts_used}")
if (n_expert_groups := self.hparams.get("n_group")) is not None:
self.gguf_writer.add_expert_group_count(n_expert_groups)
logger.info(f"gguf: expert groups count = {n_expert_groups}")
if (n_group_used := self.hparams.get("topk_group")) is not None:
self.gguf_writer.add_expert_group_used_count(n_group_used)
logger.info(f"gguf: expert groups used count = {n_group_used}")
if (head_dim := self.hparams.get("head_dim")) is not None:
self.gguf_writer.add_key_length(head_dim)
@@ -1346,6 +1503,17 @@ class MmprojModel(ModelBase):
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
def prepare_metadata(self, vocab_only: bool):
super().prepare_metadata(vocab_only=vocab_only)
output_type: str = self.ftype.name.partition("_")[2]
if self.fname_out.is_dir():
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
else:
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
def set_gguf_parameters(self):
self.gguf_writer.add_file_type(self.ftype)
@@ -1360,11 +1528,11 @@ class MmprojModel(ModelBase):
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
# preprocessor config
image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
self.gguf_writer.add_vision_image_mean(image_mean)
self.gguf_writer.add_vision_image_std(image_std)
@@ -2033,6 +2201,9 @@ class LlamaModel(TextModel):
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def _set_vocab_mistral(self):
if not _mistral_common_installed:
raise ImportError(_mistral_import_error_msg)
vocab = MistralVocab(self.dir_model)
logger.info(
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
@@ -2289,18 +2460,21 @@ class ArceeModel(LlamaModel):
)
class LlavaVisionModel(MmprojModel):
img_break_tok_id = -1
use_break_tok = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams.get("model_type") == "pixtral":
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
if self.use_break_tok:
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
elif self.is_mistral_format:
# hparams is already vision config here so norm_eps is only defined in global_config.
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
if self.use_break_tok:
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
logger.info(f"Image break token id: {self.img_break_tok_id}")
@@ -3678,7 +3852,43 @@ class Qwen2MoeModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# handle aggregated expert tensors
# GGUF stores dimensions reversed from PyTorch, so:
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
# Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
mapped = f"{name}.weight" if not name.endswith(".weight") else name
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
permuted = data_torch.permute(0, 2, 1).contiguous()
return [(self.map_tensor_name(mapped), permuted)]
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-1] // 2
gate = data_torch[..., :split_dim].contiguous()
up = data_torch[..., split_dim:].contiguous()
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
perm_gate = gate.permute(0, 2, 1).contiguous()
perm_up = up.permute(0, 2, 1).contiguous()
return [
(self.map_tensor_name(mapped_gate), perm_gate),
(self.map_tensor_name(mapped_up), perm_up),
]
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
# skip visual tensors
return []
if name.find("experts") != -1:
@@ -3791,6 +4001,10 @@ class Qwen3Model(Qwen2Model):
return torch.stack([true_row, false_row], dim=0)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "model.vision_" in name:
# skip multimodal tensors
return []
if self.is_rerank:
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
is_real_head = not self.is_tied_embeddings and "lm_head" in name
@@ -3826,6 +4040,187 @@ class Qwen3MoeModel(Qwen2MoeModel):
super().set_vocab()
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# Compute image_size if not present
if "image_size" not in self.hparams_vision:
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
patch_size = self.hparams_vision.get("patch_size", 16)
# num_position_embeddings = (image_size / patch_size) ** 2
# So image_size = sqrt(num_position_embeddings) * patch_size
image_size = int(num_pos**0.5 * patch_size)
self.hparams_vision["image_size"] = image_size
# Rename config values for compatibility
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
self.is_deepstack_layers[idx] = True
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
self.gguf_writer.add_vision_use_gelu(True)
if self.hparams_vision is not None:
merge_size = self.hparams_vision.get("spatial_merge_size")
if merge_size is not None:
self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
# Use text config's rms_norm_eps for vision attention layernorm eps
rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
if self.is_deepstack_layers:
self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
assert self.hparams_vision is not None
# Skip text model tensors - they go in the text model file
if name.startswith("model.language_model.") or name.startswith("lm_head."):
return []
if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.", 1)
if name.startswith("visual.deepstack_merger_list."):
prefix, rest = name.split(".", maxsplit=3)[2:]
# prefix is the layer index, convert to absolute clip layer index!
idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
target = rest
tensor_type: gguf.MODEL_TENSOR
if target.startswith("norm."):
tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
suffix = target.split(".", 1)[1]
elif target.startswith("linear_fc1."):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
suffix = target.split(".", 1)[1]
elif target.startswith("linear_fc2."):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
suffix = target.split(".", 1)[1]
else:
raise ValueError(f"Unexpected deepstack tensor: {name}")
new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
return [(new_name, data_torch)]
if name.startswith("visual.merger."):
suffix = name.split(".", 2)[2]
if suffix.startswith("linear_fc"):
fc_idx_str, tail = suffix.split(".", 1)
fc_num = int(fc_idx_str.replace("linear_fc", ""))
# Qwen3VL has linear_fc1 and linear_fc2
# Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
if fc_num == 1:
fc_idx = 0
elif fc_num == 2:
fc_idx = 2
else:
raise ValueError(f"unexpected fc index {fc_num} in {name}")
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
elif suffix.startswith("norm."):
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
else:
raise ValueError(f"Unexpected merger tensor: {name}")
return [(new_name, data_torch)]
if name == "visual.patch_embed.proj.weight":
# split Conv3D into Conv2Ds along temporal dimension
c1, c2, kt, _, _ = data_torch.shape
del c1, c2
if kt != 2:
raise ValueError("Current implementation only supports temporal_patch_size of 2")
return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
]
if name == "visual.patch_embed.proj.bias":
# Include the bias - it's used by the C++ code
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
if name.startswith("visual."):
return [(self.map_tensor_name(name), data_torch)]
# Fall back to parent class for other tensors
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3VLForConditionalGeneration")
class Qwen3VLTextModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.QWEN3VL
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5, use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
logger.info(f"MRoPE sections: {mrope_section[:4]}")
vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return []
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
class Qwen3VLMoeTextModel(Qwen3MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5, use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
logger.info(f"MRoPE sections: {mrope_section[:4]}")
vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return []
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
@@ -4358,27 +4753,6 @@ class CodeShellModel(TextModel):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)
_has_tok_embd = False
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
new_name = self.map_tensor_name(name)
# assuming token_embd.weight is seen before output.weight
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
self.tensor_names.remove("transformer.wte.weight")
elif new_name == tok_embd_name:
self._has_tok_embd = True
return [(new_name, data_torch)]
@ModelBase.register("InternLM2ForCausalLM")
class InternLM2Model(TextModel):
@@ -8089,8 +8463,6 @@ class BailingMoeV2Model(TextModel):
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_group_count(hparams["n_group"])
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if hparams["score_function"] == "sigmoid":
@@ -8810,6 +9182,13 @@ class SmolLM3Model(LlamaModel):
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS
# TODO: remove once MXFP4 is supported more generally
def dequant_model(self):
quant_config = self.hparams.get("quantization_config")
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
return
return super().dequant_model()
def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
@@ -9212,7 +9591,7 @@ class MistralModel(LlamaModel):
@staticmethod
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
assert TokenizerVersion is not None, "mistral_common is not installed"
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
)
@@ -9280,6 +9659,21 @@ class PixtralModel(LlavaVisionModel):
return super().map_tensor_name(name, try_suffixes)
@ModelBase.register("LightOnOCRForConditionalGeneration")
class LightOnOCRVisionModel(LlavaVisionModel):
is_mistral_format = False
use_break_tok = False
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name = name.replace("model.vision_encoder.", "vision_tower.")
name = name.replace("model.vision_projection.", "multi_modal_projector.")
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiVLForConditionalGeneration")
class KimiVLModel(MmprojModel):
def __init__(self, *args, **kwargs):
@@ -9316,6 +9710,37 @@ class KimiVLModel(MmprojModel):
return [] # skip other tensors
@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.COGVLM)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if not name.startswith("model.vision."):
return []
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("CogVLMForCausalLM")
class CogVLMModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.COGVLM
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# block vision tensors
if name.startswith("model.vision."):
return []
return [(self.map_tensor_name(name), data_torch)]
###### CONVERSION LOGIC ######
@@ -9589,11 +10014,9 @@ def main() -> None:
logger.info(f"Loading model: {dir_model.name}")
if args.mmproj:
if "mmproj" not in fname_out.name:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
is_mistral_format = args.mistral_format
if is_mistral_format and not _mistral_common_installed:
raise ImportError(_mistral_import_error_msg)
disable_mistral_community_chat_template = args.disable_mistral_community_chat_template
with torch.inference_mode():

View File

@@ -261,10 +261,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
```bash
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
&& cmake --build build --config Release -- -j 16
```
Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system.
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager.
@@ -282,17 +284,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm
```bash
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \
HIP_DEVICE_LIB_PATH=<directory-you-just-found> \
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
&& cmake --build build -- -j 16
```
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
```bash
set PATH=%HIP_PATH%\bin;%PATH%
cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.

View File

@@ -79,7 +79,7 @@ Legend:
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |

View File

@@ -5637,25 +5637,25 @@
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL"
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL"
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL"
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL"
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL"
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL"
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
"SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL"
Can't render this file because it is too large.

View File

@@ -38,6 +38,7 @@ The above command will output space-separated float values.
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
| 'json' | openai style |
| 'json+' | add cosine similarity matrix |
| 'raw' | plain text output |
### --embd-separator $"string"$
| $"string"$ | |

View File

@@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
}
// plain, pipe-friendly output: one embedding per line
static void print_raw_embeddings(const float * emb,
int n_embd_count,
int n_embd,
const llama_model * model,
enum llama_pooling_type pooling_type,
int embd_normalize) {
const uint32_t n_cls_out = llama_model_n_cls_out(model);
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
for (int j = 0; j < n_embd_count; ++j) {
for (int i = 0; i < cols; ++i) {
if (embd_normalize == 0) {
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
} else {
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
}
}
LOG("\n");
}
}
int main(int argc, char ** argv) {
common_params params;
@@ -372,6 +395,8 @@ int main(int argc, char ** argv) {
}
if (notArray) LOG("\n}\n");
} else if (params.embd_out == "raw") {
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
}
LOG("\n");

View File

@@ -371,8 +371,17 @@ class SchemaConverter:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
if isinstance(target, list):
try:
sel_index = int(sel)
except ValueError:
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel_index]
else:
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
@@ -547,7 +556,8 @@ class SchemaConverter:
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
ref_fragment = ref.split('#')[-1]
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]

View File

@@ -138,7 +138,7 @@ if model_path is None:
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
)
config = AutoConfig.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print("Model type: ", config.model_type)
print("Vocab size: ", config.vocab_size)
@@ -148,8 +148,8 @@ print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if unreleased_model_name:
model_name_lower = unreleased_model_name.lower()
@@ -171,7 +171,7 @@ if unreleased_model_name:
exit(1)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload"
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
)
for name, module in model.named_modules():

View File

@@ -242,6 +242,7 @@
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
#define GGML_MROPE_SECTIONS 4

View File

@@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
}
if (best_fit_block == -1) {
// no suitable block found, try the last block (this will grow a chunks size)
// no suitable block found, try the last block (this may grow a chunks size)
int64_t best_reuse = INT64_MIN;
for (int c = 0; c < alloc->n_chunks; ++c) {
struct tallocr_chunk * chunk = alloc->chunks[c];
if (chunk->n_free_blocks > 0) {
struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
max_avail = MAX(max_avail, block->size);
if (block->size >= size) {
int64_t reuse_factor = chunk->max_size - block->offset - size;
// reuse_factor < 0 : amount of extra memory that needs to be allocated
// reuse_factor = 0 : allocated free space exactly matches tensor size
// reuse_factor > 0 : superfluous memory that will remain unused
bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
if (block->size >= size && (better_reuse || better_fit)) {
best_fit_chunk = c;
best_fit_block = chunk->n_free_blocks - 1;
break;
best_reuse = reuse_factor;
}
}
}
@@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, addr, tensor);
size_t cur_max = addr.offset + size;
if (cur_max > alloc->max_size[addr.chunk]) {
if (cur_max > chunk->max_size) {
// sort allocated_tensors by chunk/offset
for (int i = 0; i < 1024; i++) {
for (int j = i + 1; j < 1024; j++) {

View File

@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
theta_scale_ne, theta_scale_nb, 1);
float start = 0;
float step = 1;
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
theta_scale_nb, GGML_MAX_DIMS);
theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);

View File

@@ -67,19 +67,30 @@
GGML_ABORT("CANN error");
}
// Thread-local variable to record the current device of this thread.
thread_local int g_current_cann_device = -1;
/**
* @brief Sets the device to be used by CANN.
* @brief Set the CANN device to be used.
*
* @param device The device ID to set.
* @param device The target device ID to set.
*/
void ggml_cann_set_device(const int32_t device) {
int current_device = -1;
aclrtGetDevice(&current_device);
// int current_device = -1;
// Note: In some CANN versions, if no device has been set yet,
// aclrtGetDevice(&current_device) may return 0 by default.
// aclrtGetDevice(&current_device);
if (device == current_device) {
// If the current device is already the target one, no need to switch.
if (device == g_current_cann_device) {
return;
}
// Switch to the new device.
ACL_CHECK(aclrtSetDevice(device));
// Update the global device record.
g_current_cann_device = device;
}
/**

View File

@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
chunk_size = 64;
}
#if defined(__aarch64__)
// disable for ARM
const bool disable_chunking = true;
#else
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

View File

@@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
}
static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,14 +5509,26 @@ static void ggml_mrope_cache_init(
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
theta = theta_h;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
theta = theta_w;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
theta = theta_t;
} else {
theta = theta_e;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
}
rope_yarn(
@@ -5589,6 +5601,7 @@ static void ggml_compute_forward_rope_f32(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -5627,7 +5640,7 @@ static void ggml_compute_forward_rope_f32(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
@@ -5775,6 +5788,7 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -5813,7 +5827,7 @@ static void ggml_compute_forward_rope_f16(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
@@ -7519,8 +7533,8 @@ static void ggml_compute_forward_upscale_f32(
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
pixel_offset = 0.0f;
sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
}
for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -7909,10 +7923,10 @@ void ggml_compute_forward_argsort(
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16(
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst) {
ggml_tensor * dst,
int ir0, int ir1) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
@@ -7928,9 +7942,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;
@@ -7964,16 +7975,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// parallelize by q rows using ggml_vec_dot_f32
// total rows in q
const int nr = neq1*neq2*neq3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
@@ -8000,6 +8001,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
int ith = params->ith;
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
@@ -8147,6 +8150,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
}
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);
GGML_ASSERT(neq1 == N);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// parallelize by q rows using ggml_vec_dot_f32
// total rows in q
const int64_t nr = neq1*neq2*neq3;
// rows per thread
const int ith = params->ith;
const int nth = params->nth;
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}
void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
ggml_tensor * dst) {

View File

@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
return false;
}
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
ggml_tensor * dst = op;
GGML_TENSOR_BINARY_OP_LOCALS
const void * src1_wdata = params->wdata;
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (ne11 > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
}
void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
}
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
int64_t nr = ggml_nrows(op->src[0]);
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
const void * src1_wdata = params->wdata;
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_start >= src0_end) {
return;
}
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (ne11 > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
while (current_chunk < nchunk) {
int64_t src0_start = (current_chunk * ne01) / nchunk;
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_start >= src0_end) {
break;
}
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}

View File

@@ -1,5 +1,81 @@
#include "argsort.cuh"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
const int col = blockIdx.x * blockDim.x + threadIdx.x;
const int row = blockIdx.y;
if (col < ncols && row < nrows) {
indices[row * ncols + col] = col;
}
}
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx <= nrows) {
offsets[idx] = idx * ncols;
}
}
#ifdef GGML_CUDA_USE_CUB
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
int * d_offsets = offsets_alloc.get();
static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
}
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
} else {
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
}
}
#endif // GGML_CUDA_USE_CUB
// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
@@ -11,7 +87,7 @@ template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -65,21 +141,28 @@ static int next_power_of_2(int x) {
return n;
}
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
static void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ABORT("fatal error");
}
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
#ifdef GGML_CUDA_USE_CUB
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
if (shared_mem > max_shared_mem || ncols > 1024) {
ggml_cuda_pool & pool = ctx.pool();
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
}
#else
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
#endif
}

View File

@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
if (block_nums.z > 65535) {
if (block_nums.z > 65535 || block_nums.y > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));

View File

@@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
static const uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
static const uint3 init_fastdiv_values(uint64_t d_64) {
GGML_ASSERT(d_64 != 0);
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
uint32_t d = (uint32_t)d_64;
// compute L = ceil(log2(d));
uint32_t L = 0;
@@ -1005,3 +1008,16 @@ struct ggml_backend_cuda_context {
return pool(device);
}
};
struct ggml_cuda_mm_fusion_args_host {
const ggml_tensor * x_bias = nullptr;
const ggml_tensor * gate = nullptr;
const ggml_tensor * gate_bias = nullptr;
ggml_glu_op glu_op;
};
struct ggml_cuda_mm_fusion_args_device {
const void * x_bias = nullptr;
const void * gate = nullptr;
const void * gate_bias = nullptr;
ggml_glu_op glu_op;
};

View File

@@ -1,3 +1,4 @@
#pragma once
#include "common.cuh"
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256

View File

@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
cpy_blck(cx + x_offset, cdst + dst_offset);
}
template<typename src_t, typename dst_t>
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const src_t * x = (const src_t *) cx;
dst_t * dst = (dst_t *) cdst;
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_contiguous_cuda(
const char * cx, char * cdst, const int64_t ne,
cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
}
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@@ -50,6 +50,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml.h"
@@ -1957,8 +1958,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
size_t src1_stride_size = sizeof(cuda_t);
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
const int threads_x = 16;
const int threads_y = 16;
dim3 block_dims(threads_x, threads_y);
dim3 grid_dims(
(ne13 + threads_x - 1) / threads_x,
(ne12 + threads_y - 1) / threads_y
);
k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
src0_ptr, src1_ptr, dst_t,
ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,
@@ -2007,6 +2015,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
const ggml_tensor * ffn_gate,
const ggml_tensor * glu,
const ggml_tensor * ffn_up_bias = nullptr,
const ggml_tensor * ffn_gate_bias = nullptr) {
const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
return false;
}
const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
GGML_ASSERT(ffn_up && ffn_gate && glu);
if (!is_mul_mat && !is_mul_mat_id) {
return false;
}
const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (has_bias) {
if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
return false;
}
if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
return false;
}
if (expected_bias_op == GGML_OP_ADD) {
const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
if (!up_has_mul || !gate_has_mul) {
return false;
}
} else { // GGML_OP_ADD_ID
if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
return false;
}
if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
return false;
}
}
} else {
if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
return false;
}
}
if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
!ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
return false;
}
if (ffn_up->src[1] != ffn_gate->src[1]) {
return false;
}
if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
return false;
}
static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
return false;
}
if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
return false;
}
const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
//TODO: add support for fusion for split buffers
if (split) {
return false;
}
return true;
}
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
const ggml_tensor * dst = tensor;
const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
bool use_mul_mat_vec_f =
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
//we only support fusion for ncols_dst = 1
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
return false;
}
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
return false;
}
return use_mul_mat_vec_f;
}
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
const ggml_tensor * dst = tensor;
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
src0->view_src;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
// fusion is not universally faster on Pascal
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (cc <= GGML_CUDA_CC_PASCAL) {
return false;
}
//we only support fusion for ncols_dst = 1
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
return false;
}
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
return false;
}
return use_mul_mat_vec_q;
}
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
@@ -2268,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SET_ROWS:
ggml_cuda_op_set_rows(ctx, dst);
break;
case GGML_OP_SET:
ggml_cuda_op_set(ctx, dst);
break;
case GGML_OP_DUP:
ggml_cuda_dup(ctx, dst);
break;
@@ -2745,7 +2897,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
}
}
if (node->op == GGML_OP_SCALE &&
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
@@ -2826,9 +2978,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+8];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
@@ -2836,16 +2988,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+4];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
}
}
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@@ -2854,6 +3006,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
return true;
}
}
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
return true;
}
}
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
@@ -2934,9 +3118,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {
[[maybe_unused]] int prev_i = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
#ifdef GGML_CUDA_DEBUG
const int nodes_fused = i - prev_i - 1;
prev_i = i;
if (nodes_fused > 0) {
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
}
#endif
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
@@ -2945,17 +3140,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i+8];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_tensor * weights = cgraph->nodes[i + 9];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false);
i += 8;
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i+4];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
@@ -3004,6 +3200,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
}
bool fused_mul_mat_vec = false;
int fused_node_count = 0;
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
ggml_tensor * glu = cgraph->nodes[i + 4];
ggml_tensor * gate_bias_n = glu->src[0];
ggml_tensor * up_bias_n = glu->src[1];
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
ggml_tensor * gate_n = nullptr;
ggml_tensor * up_n = nullptr;
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
gate_n = cgraph->nodes[i];
up_n = cgraph->nodes[i + 2];
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
gate_n = cgraph->nodes[i + 2];
up_n = cgraph->nodes[i];
} else {
continue;
}
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
if (op_bias == GGML_OP_ADD) {
if (bias_node->src[0] == mul_node) {
return bias_node->src[1];
}
if (bias_node->src[1] == mul_node) {
return bias_node->src[0];
}
return (ggml_tensor *) nullptr;
}
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
GGML_ASSERT(bias_node->src[0] == mul_node);
return bias_node->src[1];
};
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
if (!up_bias_tensor || !gate_bias_tensor) {
continue;
}
const ggml_tensor * src0 = up_n->src[0];
const ggml_tensor * src1 = up_n->src[1];
const ggml_tensor * ids = up_n->src[2];
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate_n->src[0];
fusion_data.x_bias = up_bias_tensor;
fusion_data.gate_bias = gate_bias_tensor;
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 5;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate_n->src[0];
fusion_data.x_bias = up_bias_tensor;
fusion_data.gate_bias = gate_bias_tensor;
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 5;
break;
}
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
ggml_tensor * glu = cgraph->nodes[i + 2];
ggml_tensor * gate = glu->src[0];
ggml_tensor * up = glu->src[1];
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
|| (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
if (!ok) continue;
const ggml_tensor * src0 = up->src[0];
const ggml_tensor * src1 = up->src[1];
const ggml_tensor * ids = up->src[2];
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate->src[0];
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 3;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.gate = gate->src[0];
fusion_data.glu_op = ggml_get_glu_op(glu);
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 3;
break;
}
}
}
if (fused_mul_mat_vec) {
i += fused_node_count - 1;
continue;
}
fused_mul_mat_vec = false;
fused_node_count = 0;
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
continue;
}
ggml_tensor * mm_node = cgraph->nodes[i];
ggml_tensor * bias_node = cgraph->nodes[i + 1];
ggml_tensor * bias_tensor = nullptr;
if (bias_op == GGML_OP_ADD) {
if (bias_node->src[0] == mm_node) {
bias_tensor = bias_node->src[1];
} else if (bias_node->src[1] == mm_node) {
bias_tensor = bias_node->src[0];
} else {
continue;
}
} else {
if (bias_node->src[0] != mm_node) {
continue;
}
bias_tensor = bias_node->src[1];
}
const ggml_tensor * src0 = mm_node->src[0];
const ggml_tensor * src1 = mm_node->src[1];
const ggml_tensor * ids = mm_node->src[2];
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
continue;
}
ggml_cuda_mm_fusion_args_host fusion_data{};
fusion_data.x_bias = bias_tensor;
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 2;
break;
}
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
fused_mul_mat_vec = true;
fused_node_count = 2;
break;
}
}
if (fused_mul_mat_vec) {
i += fused_node_count - 1;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
@@ -3483,6 +3857,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
op->src[0]->type == GGML_TYPE_F32 &&
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
} break;
case GGML_OP_SET:
{
const ggml_type t = op->type;
return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
t == op->src[0]->type &&
t == op->src[1]->type;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
@@ -3642,8 +4023,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
#ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024;
#else
return true;
#endif
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:

View File

@@ -1,11 +1,12 @@
#include "ggml.h"
#include "common.cuh"
#include "convert.cuh"
#include "unary.cuh"
#include "mmvf.cuh"
#include "convert.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size>
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
bool use_gate = false;
bool use_bias = false;
bool use_gate_bias = false;
ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
const T * gate_x = nullptr;
const float * x_bias = nullptr;
const float * gate_bias = nullptr;
if constexpr (has_fusion) {
use_gate = fusion.gate != nullptr;
use_bias = fusion.x_bias != nullptr;
use_gate_bias = fusion.gate_bias != nullptr;
glu_op = fusion.glu_op;
if (use_gate) {
gate_x = static_cast<const T *>(fusion.gate);
}
if (use_bias) {
x_bias = static_cast<const float *>(fusion.x_bias);
}
if (use_gate_bias) {
gate_bias = static_cast<const float *>(fusion.gate_bias);
use_gate_bias = use_gate;
} else {
use_gate_bias = false;
}
}
if (use_gate) {
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
}
if constexpr (has_fusion) {
const int channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
}
if (use_gate_bias) {
gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
}
}
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
float * buf_iw = (float *) data_mmv;
float * buf_iw_gate = nullptr;
if constexpr (has_fusion) {
buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
}
if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
if constexpr (has_fusion) {
if (use_gate) {
buf_iw_gate[tid] = 0.0f;
}
}
}
__syncthreads();
}
float sumf[ncols_dst] = {0.0f};
float sumf_gate[ncols_dst];
if constexpr (has_fusion) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf_gate[j] = 0.0f;
}
}
if constexpr (std::is_same_v<T, float>) {
const float2 * x2 = (const float2 *) x;
const float2 * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const float2 *) gate_x;
}
}
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = x2[col2];
float2 tmpx_gate = make_float2(0.0f, 0.0f);
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
}
}
}
}
} else if constexpr (std::is_same_v<T, half>) {
const half2 * x2 = (const half2 *) x;
const half2 * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const half2 *) gate_x;
}
}
if (std::is_same_v<type_acc, float>) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
float2 tmpx_gate = make_float2(0.0f, 0.0f);
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = __half22float2(gate_x2[col2]);
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
}
}
}
}
} else {
#ifdef FP16_AVAILABLE
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const half2 tmpx = x2[col2];
half2 tmpx_gate = make_half2(0.0f, 0.0f);
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
}
}
}
}
@@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f(
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
}
if constexpr (has_fusion) {
if (use_gate) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
}
}
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
@@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f(
//TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
const int * x2 = (const int *) x;
const int * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const int *) gate_x;
}
}
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
int tmpx_gate = 0;
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
@@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f(
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
}
}
}
}
#else
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
const nv_bfloat162 * gate_x2 = nullptr;
if constexpr (has_fusion) {
if (use_gate) {
gate_x2 = (const nv_bfloat162 *) gate_x;
}
}
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const nv_bfloat162 tmpx = x2[col2];
nv_bfloat162 tmpx_gate;
if constexpr (has_fusion) {
if (use_gate) {
tmpx_gate = gate_x2[col2];
}
}
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
if constexpr (has_fusion) {
if (use_gate) {
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
}
}
}
}
#endif
@@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f(
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
if constexpr (has_fusion) {
if (use_gate) {
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
}
}
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf[j];
if constexpr (has_fusion) {
if (use_gate) {
buf_iw_gate[tid/warp_size] = sumf_gate[j];
}
}
__syncthreads();
if (tid < warp_size) {
sumf[j] = buf_iw[tid];
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
if constexpr (has_fusion) {
if (use_gate) {
sumf_gate[j] = buf_iw_gate[tid];
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
}
}
}
if (j < ncols_dst) {
__syncthreads();
}
@@ -139,12 +313,74 @@ static __global__ void mul_mat_vec_f(
return;
}
dst[tid*stride_col_dst + row] = sumf[tid];
float value = sumf[tid];
if constexpr (has_fusion) {
if (use_bias) {
value += x_bias[tid*stride_col_dst + row];
}
if (use_gate) {
float gate_value = sumf_gate[tid];
if (use_gate_bias) {
gate_value += gate_bias[tid*stride_col_dst + row];
}
switch (glu_op) {
case GGML_GLU_OP_SWIGLU:
value *= ggml_cuda_op_silu_single(gate_value);
break;
case GGML_GLU_OP_GEGLU:
value *= ggml_cuda_op_gelu_single(gate_value);
break;
case GGML_GLU_OP_SWIGLU_OAI: {
value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
break;
}
default:
break;
}
}
}
dst[tid*stride_col_dst + row] = value;
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
}
}
template<typename T, typename type_acc, int ncols_dst, int block_size>
static void mul_mat_vec_f_switch_fusion(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -176,57 +412,59 @@ static void launch_mul_mat_vec_f_cuda(
}
}
const int nbytes_shared = warp_size*sizeof(float);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 64: {
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 96: {
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 128: {
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 160: {
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 192: {
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 224: {
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
case 256: {
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
} break;
default: {
GGML_ABORT("fatal error");
@@ -236,7 +474,7 @@ static void launch_mul_mat_vec_f_cuda(
template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, float * dst,
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -246,49 +484,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
@@ -300,29 +538,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
template<typename T>
static void mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -348,6 +588,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
ggml_cuda_mm_fusion_args_device fusion_local{};
if (fusion) {
GGML_ASSERT( !ids || dst->ne[2] == 1);
GGML_ASSERT( ids || dst->ne[1] == 1);
if (fusion->x_bias) {
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
fusion_local.x_bias = fusion->x_bias->data;
}
if (fusion->gate) {
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
fusion_local.gate = fusion->gate->data;
}
if (fusion->gate_bias) {
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
fusion_local.gate_bias = fusion->gate_bias->data;
}
fusion_local.glu_op = fusion->glu_op;
}
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
@@ -370,19 +634,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
@@ -409,7 +673,6 @@ void ggml_cuda_op_mul_mat_vec_f(
const int cc = ggml_cuda_info().devices[id].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00;
const int64_t stride_col_y = ne10;
@@ -426,22 +689,23 @@ void ggml_cuda_op_mul_mat_vec_f(
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
ggml_cuda_mm_fusion_args_device empty{};
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;

View File

@@ -1,6 +1,7 @@
#include "common.cuh"
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,

View File

@@ -1,5 +1,6 @@
#include "mmvq.cuh"
#include "quantize.cuh"
#include "unary.cuh"
#include "vecdotq.cuh"
#include <cstdint>
@@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
case 1:
@@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
template <ggml_type type, int ncols_dst>
// tell the compiler to use as many registers as it wants, see nwarps definition below
template <ggml_type type, int ncols_dst, bool has_fusion>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
@@ -169,8 +170,54 @@ static __global__ void mul_mat_vec_q(
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
bool use_gate = false;
bool use_bias = false;
bool use_gate_bias = false;
const void * vgate = nullptr;
const float * x_bias = nullptr;
const float * gate_bias = nullptr;
ggml_glu_op active_glu;
if constexpr (has_fusion) {
use_gate = fusion.gate != nullptr;
use_bias = fusion.x_bias != nullptr;
use_gate_bias = fusion.gate_bias != nullptr && use_gate;
vgate = fusion.gate;
x_bias = (const float *) fusion.x_bias;
gate_bias = (const float *) fusion.gate_bias;
active_glu = fusion.glu_op;
}
const uint32_t channel_bias = ids ? channel_x : channel_dst;
float x_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
if constexpr (has_fusion) {
if (use_bias) {
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
// 1. Hide latency by prefetching bias and gate here
// 2. load only on threads that won't die after partial sum calculation
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
for (int j = 0; j < ncols_dst; ++j) {
x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
}
}
}
if (use_gate_bias) {
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
for (int j = 0; j < ncols_dst; ++j) {
gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
}
}
}
}
// partial sum for each thread
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
@@ -187,17 +234,35 @@ static __global__ void mul_mat_vec_q(
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda(
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
if constexpr (has_fusion) {
if (use_gate) {
tmp_gate[j][i] += vec_dot_q_cuda(
vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
}
}
}
}
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
if constexpr (!has_fusion) {
(void) tmp_shared_gate;
} else if (!use_gate) {
(void) tmp_shared_gate;
}
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
if constexpr (has_fusion) {
if (use_gate) {
tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
}
}
}
}
}
@@ -216,14 +281,55 @@ static __global__ void mul_mat_vec_q(
#pragma unroll
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
if constexpr (has_fusion) {
if (use_gate) {
tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
}
}
}
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
if constexpr (has_fusion) {
if (use_gate) {
tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
}
}
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
float result = tmp[j][threadIdx.x];
if constexpr (has_fusion) {
if (use_bias) {
result += x_biases[j][threadIdx.x];
}
if (use_gate) {
float gate_value = tmp_gate[j][threadIdx.x];
if (use_gate_bias) {
gate_value += gate_biases[j][threadIdx.x];
}
switch (active_glu) {
case GGML_GLU_OP_SWIGLU:
result *= ggml_cuda_op_silu_single(gate_value);
break;
case GGML_GLU_OP_GEGLU:
result *= ggml_cuda_op_gelu_single(gate_value);
break;
case GGML_GLU_OP_SWIGLU_OAI: {
result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
break;
}
default:
result = result * gate_value;
break;
}
}
}
dst[j*stride_col_dst + threadIdx.x] = result;
}
}
if constexpr (!has_fusion) {
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
}
}
static std::pair<dim3, dim3> calc_launch_params(
@@ -235,9 +341,37 @@ static std::pair<dim3, dim3> calc_launch_params(
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
const void * vx, const void * vy, const int32_t * ids, float * dst,
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int ncols_x, const int nrows_x, const int ncols_dst,
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@@ -256,80 +390,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 2: {
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 3: {
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 4: {
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 5: {
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 6: {
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 7: {
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
case 8: {
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break;
default:
GGML_ABORT("fatal error");
break;
}
}
GGML_UNUSED(has_fusion);
}
static void mul_mat_vec_q_switch_type(
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int ncols_x, const int nrows_x, const int ncols_dst,
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@@ -339,143 +476,123 @@ static void mul_mat_vec_q_switch_type(
switch (type_x) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
default:
GGML_ABORT("fatal error");
@@ -484,7 +601,8 @@ static void mul_mat_vec_q_switch_type(
}
void ggml_cuda_mul_mat_vec_q(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
@@ -508,6 +626,31 @@ void ggml_cuda_mul_mat_vec_q(
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
ggml_cuda_mm_fusion_args_device fusion_local{};
if (fusion) {
GGML_ASSERT( !ids || dst->ne[2] == 1);
GGML_ASSERT( ids || dst->ne[1] == 1);
if (fusion->x_bias) {
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
fusion_local.x_bias = fusion->x_bias->data;
}
if (fusion->gate) {
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
fusion_local.gate = fusion->gate->data;
}
if (fusion->gate_bias) {
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
fusion_local.gate_bias = fusion->gate_bias->data;
}
fusion_local.glu_op = fusion->glu_op;
}
// If src0 is a temporary compute buffer, clear any potential padding.
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
const size_t size_data = ggml_nbytes(src0);
@@ -549,10 +692,10 @@ void ggml_cuda_mul_mat_vec_q(
const int64_t stride_channel_y = ids ? s11 : s12;
mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, stream);
ne03, ne3, s03, s13, s3, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
@@ -578,8 +721,9 @@ void ggml_cuda_op_mul_mat_vec_q(
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
const int stride_col_y = src1_padded_row_size / QK8_1;
ggml_cuda_mm_fusion_args_device fusion_local{};
mul_mat_vec_q_switch_type(
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);

View File

@@ -3,7 +3,7 @@
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,

View File

@@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
@@ -152,17 +152,29 @@ static __global__ void rope_multi(
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -276,7 +288,7 @@ template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -287,11 +299,11 @@ static void rope_multi_cuda(
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
@@ -369,6 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -406,11 +419,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
}

View File

@@ -4,30 +4,53 @@
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
// Generic quantized set_rows kernel template
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s10, const int64_t s11, const int64_t s12,
const int64_t s1, const int64_t s2, const int64_t s3) {
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
const idx_t * __restrict__ src1,
block_type * __restrict__ dst,
const int64_t ne_total,
const int64_t ne10,
const int64_t ne11,
const int64_t ne12,
const int64_t ne13,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t s10,
const int64_t s11,
const int64_t s12,
const int64_t s1,
const int64_t s2,
const int64_t s3,
const uint3 ne00,
const uint3 ne01,
const uint3 ne02,
const uint3 ne11_fd,
const uint3 ne12_fd) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
if (i >= ne_total) {
return;
}
const int64_t i_base = i * qk;
const int64_t i03 = i_base / (ne00 * ne01 * ne02);
const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
uint32_t tmp = (uint32_t) i_base;
uint2 div_mod;
const int64_t i12 = i03 % ne12;
const int64_t i11 = i02 % ne11;
div_mod = fast_div_modulo(tmp, ne00);
const int64_t i00 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne01);
const int64_t i01 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne02);
const int64_t i02 = div_mod.y;
const int64_t i03 = div_mod.x;
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
quantize_func(src_block, dst_block);
GGML_UNUSED(ne10);
GGML_UNUSED(ne11);
GGML_UNUSED(ne12);
GGML_UNUSED(ne13);
}
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
const int64_t s2 = nb2;
const int64_t s3 = nb3;
if (ne_total > 0) {
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
s01, s02, s03,
s10, s11, s12,
s1, s2, s3);
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
ne01_fd, ne02_fd, ne11_fd, ne12_fd);
}
}
template<typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(
const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s10, const int64_t s11, const int64_t s12,
const int64_t s1, const int64_t s2, const int64_t s3) {
template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(const src_t * __restrict__ src0,
const idx_t * __restrict__ src1,
dst_t * __restrict__ dst,
const int64_t ne_total,
const int64_t ne10,
const int64_t ne11,
const int64_t ne12,
const int64_t ne13,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t s10,
const int64_t s11,
const int64_t s12,
const int64_t s1,
const int64_t s2,
const int64_t s3,
const uint3 ne00,
const uint3 ne01,
const uint3 ne02,
const uint3 ne11_fd,
const uint3 ne12_fd) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
if (i >= ne_total) {
return;
}
const int64_t i03 = i / (ne00 * ne01 * ne02);
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
uint32_t tmp = (uint32_t) i;
uint2 div_mod;
const int64_t i12 = i03 % ne12;
const int64_t i11 = i02 % ne11;
div_mod = fast_div_modulo(tmp, ne00);
const int64_t i00 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne01);
const int64_t i01 = div_mod.y;
tmp = div_mod.x;
div_mod = fast_div_modulo(tmp, ne02);
const int64_t i02 = div_mod.y;
const int64_t i03 = div_mod.x;
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
GGML_UNUSED(ne10);
GGML_UNUSED(ne11);
GGML_UNUSED(ne12);
GGML_UNUSED(ne13);
}
@@ -144,14 +196,16 @@ static void set_rows_cuda(
const int64_t s2 = nb2/sizeof(dst_t);
const int64_t s3 = nb3/sizeof(dst_t);
if (ne_total > 0) {
k_set_rows<<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
s01, s02, s03,
s10, s11, s12,
s1, s2, s3);
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
ne11_fd, ne12_fd);
}
}

39
ggml/src/ggml-cuda/set.cu Normal file
View File

@@ -0,0 +1,39 @@
#include "set.cuh"
#include "cpy.cuh"
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
GGML_ASSERT(src1->type == src0->type);
GGML_ASSERT(dst ->type == src0->type);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t nb1 = ((int32_t *) dst->op_params)[0];
const size_t nb2 = ((int32_t *) dst->op_params)[1];
const size_t nb3 = ((int32_t *) dst->op_params)[2];
const size_t offset = ((int32_t *) dst->op_params)[3];
const bool inplace= (bool) ((int32_t *) dst->op_params)[4];
if (!inplace) {
ggml_cuda_cpy(ctx, src0, dst);
}
ggml_tensor dst_view = *dst;
dst_view.data = (void *)((char *)dst->data + offset);
dst_view.ne[0] = src1->ne[0];
dst_view.ne[1] = src1->ne[1];
dst_view.ne[2] = src1->ne[2];
dst_view.ne[3] = src1->ne[3];
dst_view.nb[0] = ggml_element_size(dst);
dst_view.nb[1] = nb1;
dst_view.nb[2] = nb2;
dst_view.nb[3] = nb3;
ggml_cuda_cpy(ctx, src1, &dst_view);
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include "common.cuh"
#define CUDA_SET_BLOCK_SIZE 256
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -2,6 +2,7 @@
#include "ggml.h"
#include "topk-moe.cuh"
#include <cmath>
#include <initializer_list>
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used) {
const int n_expert_used,
const float clamp_val) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
if constexpr (with_norm) {
wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum;
for (int i = 0; i < experts_per_thread; i++) {
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
weights[idx] = output_weights[i];
}
}
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
}
template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
int32_t * ids,
const int n_rows,
const int n_expert,
const int n_expert_used) {
const int n_expert_used,
const float clamp_val) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 2:
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 4:
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 8:
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 16:
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 32:
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 64:
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 128:
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 256:
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
case 512:
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
break;
default:
GGML_ASSERT(false && "fatal error");
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax) {
const bool delayed_softmax,
ggml_tensor * clamp) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const int n_expert_used = weights->ne[1];
float clamp_val = -INFINITY;
if (with_norm) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
} else {
GGML_ASSERT(clamp == nullptr);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
float scale = 1.0f;
float max_bias = 0.0f;
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
return false;
}
}
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };

View File

@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false);
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

View File

@@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
}
static __device__ __forceinline__ float op_gelu(float x) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
return ggml_cuda_op_gelu_single(x);
}
static __device__ __forceinline__ float op_gelu_erf(float x) {
@@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
}
static __device__ __forceinline__ float op_silu(float x) {
return x / (1.0f + expf(-x));
return ggml_cuda_op_silu_single(x);
}
static __device__ __forceinline__ float op_tanh(float x) {
@@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons
float xi = x[j0];
float gi = g[j1];
xi = fminf(xi, limit);
gi = fmaxf(fminf(gi, limit), -limit);
float out_glu = xi / (1.0f + expf(-xi * alpha));
out_glu = out_glu * (1.0f + gi);
dst[i] = out_glu;
dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
}
template <typename T>

View File

@@ -1,3 +1,4 @@
#pragma once
#include "common.cuh"
#define CUDA_NEG_BLOCK_SIZE 256
@@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
return x / (1.0f + expf(-x));
}
__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}
__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
x = fminf(x, limit);
g = fmaxf(fminf(g, limit), -limit);
float out_glu = x / (1.0f + expf(-x * alpha));
out_glu = out_glu * (1.0f + g);
return out_glu;
}

View File

@@ -126,8 +126,8 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
pixel_offset = 0.0f;
}
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],

View File

@@ -211,12 +211,15 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t)
// ** backend sessions
struct ggml_hexagon_session {
ggml_hexagon_session(int dev_id) noexcept(false);
ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
~ggml_hexagon_session() noexcept(true);
void allocate(int dev_id) noexcept(false);
void release() noexcept(true);
void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
void flush();
ggml_backend_buffer_type buffer_type;
ggml_backend_buffer_type repack_buffer_type;
@@ -237,15 +240,37 @@ struct ggml_hexagon_session {
uint32_t prof_pkts;
};
// Packet callback
static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * context) {
auto sess = static_cast<ggml_hexagon_session *>(context);
void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
// Bump pending flag (cleared in the session::flush once we get the responce)
this->op_pending++; // atomic inc
int err = dspqueue_write(this->queue,
0, // flags - the framework will autoset this
n_bufs, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000 // Timeout
);
if (err != 0) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
}
if (sync) {
flush();
}
}
// Flush HTP response queue i.e wait for all outstanding requests to complete
void ggml_hexagon_session::flush() {
dspqueue_t q = this->queue;
// Repeatedly read packets from the queue until it's empty. We don't
// necessarily get a separate callback for each packet, and new packets
// may arrive while we're processing the previous one.
while (1) {
while (this->op_pending) {
struct htp_general_rsp rsp;
uint32_t rsp_size;
uint32_t flags;
@@ -253,22 +278,23 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex
struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
uint32_t n_bufs;
// Read packet from queue
int err = dspqueue_read_noblock(queue, &flags,
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
&n_bufs, // Number of buffer references
bufs, // Buffer references
sizeof(rsp), // Max message length
&rsp_size, // Message length
(uint8_t *) &rsp);
// Read response packet from queue
int err = dspqueue_read(q, &flags,
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
&n_bufs, // Number of buffer references
bufs, // Buffer references
sizeof(rsp), // Max message length
&rsp_size, // Message length
(uint8_t *) &rsp,
1000000); // Timeout
if (err == AEE_EWOULDBLOCK) {
// Consumed all packets available for now
return;
if (err == AEE_EEXPIRED) {
// TODO: might need to bail out if the HTP is stuck on something
continue;
}
if (err != 0) {
GGML_ABORT("ggml-hex: dspqueue_read_noblock failed: 0x%08x\n", (unsigned) err);
GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
}
// Basic sanity checks
@@ -281,21 +307,15 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex
// TODO: handle errors
}
// FIXME: update profiling implementation
sess->prof_usecs = rsp.prof_usecs;
sess->prof_cycles = rsp.prof_cycles;
sess->prof_pkts = rsp.prof_pkts;
// TODO: update profiling implementation, currently only works for opt_opsync mode
this->prof_usecs = rsp.prof_usecs;
this->prof_cycles = rsp.prof_cycles;
this->prof_pkts = rsp.prof_pkts;
sess->op_pending--; // atomic dec
this->op_pending--; // atomic dec
}
}
// Error callback - simply terminates with an error. Used where we don't
// expect errors.
[[noreturn]] static void htp_error_callback(dspqueue_t queue, AEEResult error, void * context) {
GGML_ABORT("ggml-hex: dspcall general error 0x%x: for queue %p\n", error, (void *) queue);
}
// ** backend buffers
struct ggml_backend_hexagon_buffer_type_context {
@@ -656,6 +676,15 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
// Ensure we don't try to read more data than is available in the source buffer 'data'
// or write more than the tensor can hold.
const size_t total_tensor_size = (size_t)nrows * row_size;
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
// Calculate how many full rows and how many remaining bytes we need to process.
const int64_t n_full_rows = n_bytes_to_copy / row_size;
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
void * buf_pd = ggml_aligned_malloc(row_size_pd);
GGML_ASSERT(buf_pd != NULL);
@@ -667,7 +696,8 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
for (int64_t i = 0; i < nrows; i++) {
// 1. Process all the full rows
for (int64_t i = 0; i < n_full_rows; i++) {
const uint8_t * src = (const uint8_t *) data + (i * row_size);
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
@@ -676,6 +706,25 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
memcpy(dst, buf_rp, row_size);
}
// 2. Process the final, potentially partial, row
if (n_rem_bytes > 0) {
const int64_t i = n_full_rows;
const uint8_t * src = (const uint8_t *) data + (i * row_size);
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
// re-init the row because we are potentially copying a partial row
init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
// Copy only the remaining bytes from the source.
memcpy(buf_pd, src, n_rem_bytes);
// Repack the entire buffer
repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
// Write only the corresponding remaining bytes to the destination tensor.
memcpy(dst, buf_rp, n_rem_bytes);
}
ggml_aligned_free(buf_pd, row_size_pd);
ggml_aligned_free(buf_rp, row_size_rp);
}
@@ -688,6 +737,14 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
// Ensure we don't try to copy more data than the tensor actually contains.
const size_t total_tensor_size = (size_t)nrows * row_size;
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
// Calculate how many full rows and how many remaining bytes we need to process.
const int64_t n_full_rows = n_bytes_to_copy / row_size;
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
void * buf_pd = ggml_aligned_malloc(row_size_pd);
GGML_ASSERT(buf_pd != NULL);
@@ -699,7 +756,8 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
for (int64_t i = 0; i < nrows; i++) {
// 1. Process all the full rows
for (int64_t i = 0; i < n_full_rows; i++) {
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
uint8_t * dst = (uint8_t *) data + (i * row_size);
@@ -708,6 +766,20 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
memcpy(dst, buf_rp, row_size);
}
// 2. Process the final, potentially partial, row
if (n_rem_bytes > 0) {
const int64_t i = n_full_rows;
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
uint8_t * dst = (uint8_t *) data + (i * row_size);
// We still need to read and unpack the entire source row because quantization is block-based.
memcpy(buf_pd, src, row_size);
unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
// But we only copy the remaining number of bytes to the destination.
memcpy(dst, buf_rp, n_rem_bytes);
}
ggml_aligned_free(buf_pd, row_size_pd);
ggml_aligned_free(buf_rp, row_size_rp);
}
@@ -930,6 +1002,15 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
// Ensure we don't try to read more data than is available in the source buffer 'data'
// or write more than the tensor can hold.
const size_t total_tensor_size = (size_t)nrows * row_size;
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
// Calculate how many full rows and how many remaining bytes we need to process.
const int64_t n_full_rows = n_bytes_to_copy / row_size;
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
void * buf_pd = ggml_aligned_malloc(row_size_pd);
GGML_ASSERT(buf_pd != NULL);
@@ -941,7 +1022,8 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
for (int64_t i = 0; i < nrows; i++) {
// 1. Process all the full rows
for (int64_t i = 0; i < n_full_rows; i++) {
const uint8_t * src = (const uint8_t *) data + (i * row_size);
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
@@ -950,6 +1032,25 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
memcpy(dst, buf_rp, row_size);
}
// 2. Process the final, potentially partial, row
if (n_rem_bytes > 0) {
const int64_t i = n_full_rows;
const uint8_t * src = (const uint8_t *) data + (i * row_size);
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
// re-init the row because we are potentially copying a partial row
init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);
// Copy only the remaining bytes from the source.
memcpy(buf_pd, src, n_rem_bytes);
// Repack the entire buffer
repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
// Write only the corresponding remaining bytes to the destination tensor.
memcpy(dst, buf_rp, n_rem_bytes);
}
ggml_aligned_free(buf_pd, row_size_pd);
ggml_aligned_free(buf_rp, row_size_rp);
}
@@ -962,6 +1063,14 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
// Ensure we don't try to copy more data than the tensor actually contains.
const size_t total_tensor_size = (size_t)nrows * row_size;
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
// Calculate how many full rows and how many remaining bytes we need to process.
const int64_t n_full_rows = n_bytes_to_copy / row_size;
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
void * buf_pd = ggml_aligned_malloc(row_size_pd);
GGML_ASSERT(buf_pd != NULL);
@@ -973,7 +1082,8 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
for (int64_t i = 0; i < nrows; i++) {
// 1. Process all the full rows
for (int64_t i = 0; i < n_full_rows; i++) {
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
uint8_t * dst = (uint8_t *) data + (i * row_size);
@@ -982,6 +1092,20 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
memcpy(dst, buf_rp, row_size);
}
// 2. Process the final, potentially partial, row
if (n_rem_bytes > 0) {
const int64_t i = n_full_rows;
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
uint8_t * dst = (uint8_t *) data + (i * row_size);
// We still need to read and unpack the entire source row because quantization is block-based.
memcpy(buf_pd, src, row_size);
unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
// But we only copy the remaining number of bytes to the destination.
memcpy(dst, buf_rp, n_rem_bytes);
}
ggml_aligned_free(buf_pd, row_size_pd);
ggml_aligned_free(buf_rp, row_size_rp);
}
@@ -1229,6 +1353,15 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
// Ensure we don't try to read more data than is available in the source buffer 'data'
// or write more than the tensor can hold.
const size_t total_tensor_size = (size_t)nrows * row_size;
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
// Calculate how many full rows and how many remaining bytes we need to process.
const int64_t n_full_rows = n_bytes_to_copy / row_size;
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
void * buf_pd = ggml_aligned_malloc(row_size_pd);
GGML_ASSERT(buf_pd != NULL);
@@ -1240,7 +1373,8 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
for (int64_t i = 0; i < nrows; i++) {
// 1. Process all the full rows
for (int64_t i = 0; i < n_full_rows; i++) {
const uint8_t * src = (const uint8_t *) data + (i * row_size);
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
@@ -1249,6 +1383,25 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
memcpy(dst, buf_rp, row_size);
}
// 2. Process the final, potentially partial, row
if (n_rem_bytes > 0) {
const int64_t i = n_full_rows;
const uint8_t * src = (const uint8_t *) data + (i * row_size);
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
// re-init the row because we are potentially copying a partial row
init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);
// Copy only the remaining bytes from the source.
memcpy(buf_pd, src, n_rem_bytes);
// Repack the entire buffer (partial data + zero padding).
repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
// Write only the corresponding remaining bytes to the destination tensor.
memcpy(dst, buf_rp, n_rem_bytes);
}
ggml_aligned_free(buf_pd, row_size_pd);
ggml_aligned_free(buf_rp, row_size_rp);
}
@@ -1261,6 +1414,14 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
// Ensure we don't try to copy more data than the tensor actually contains.
const size_t total_tensor_size = (size_t)nrows * row_size;
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
// Calculate how many full rows and how many remaining bytes we need to process.
const int64_t n_full_rows = n_bytes_to_copy / row_size;
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
void * buf_pd = ggml_aligned_malloc(row_size_pd);
GGML_ASSERT(buf_pd != NULL);
@@ -1272,7 +1433,8 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
for (int64_t i = 0; i < nrows; i++) {
// 1. Process all the full rows
for (int64_t i = 0; i < n_full_rows; i++) {
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
uint8_t * dst = (uint8_t *) data + (i * row_size);
@@ -1281,6 +1443,20 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
memcpy(dst, buf_rp, row_size);
}
// 2. Process the final, potentially partial, row
if (n_rem_bytes > 0) {
const int64_t i = n_full_rows;
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
uint8_t * dst = (uint8_t *) data + (i * row_size);
// We still need to read and unpack the entire source row because the format is block-based.
memcpy(buf_pd, src, row_size);
unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
// But we only copy the remaining number of bytes to the destination to respect the size limit.
memcpy(dst, buf_rp, n_rem_bytes);
}
ggml_aligned_free(buf_pd, row_size_pd);
ggml_aligned_free(buf_rp, row_size_rp);
}
@@ -1299,19 +1475,19 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
switch (tensor->type) {
case GGML_TYPE_Q4_0:
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q4_0_q4x4x2(tensor, data, size);
break;
case GGML_TYPE_Q8_0:
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q8_0_q8x4x2(tensor, data, size);
break;
case GGML_TYPE_MXFP4:
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_mxfp4_mxfp4x4x2(tensor, data, size);
break;
@@ -1335,19 +1511,19 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
switch (tensor->type) {
case GGML_TYPE_Q4_0:
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q4x4x2_q4_0(data, tensor, size);
break;
case GGML_TYPE_Q8_0:
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q8x4x2_q8_0(data, tensor, size);
break;
case GGML_TYPE_MXFP4:
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_mxfp4x4x2_mxfp4(data, tensor, size);
break;
@@ -1564,7 +1740,8 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
0, // Flags
128 * 1024, // Request queue size (in bytes)
64 * 1024, // Response queue size (in bytes)
htp_packet_callback, htp_error_callback,
nullptr, // Read packet callback (we handle reads explicitly)
nullptr, // Error callback (we handle errors during reads)
(void *) this, // Callback context
&queue);
if (err != 0) {
@@ -1631,10 +1808,13 @@ void ggml_hexagon_session::release() noexcept(true) {
}
}
ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) {
ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
buffer_type.context = nullptr;
repack_buffer_type.context = nullptr;
buffer_type.device = dev;
repack_buffer_type.device = dev;
try {
allocate(dev_id);
@@ -2202,7 +2382,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
bufs[0].ptr = src0->data;
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
bufs[0].size = ggml_nbytes(src0);
bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF;
bufs[0].flags = 0;
// Second buffer Input Activations. This is a buffer that the CPU
// writes and the DSP reads, so we'll need to flush CPU caches and
@@ -2212,8 +2392,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
bufs[1].ptr = src1->data;
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
bufs[1].size = ggml_nbytes(src1);
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
// Third buffer Output Activations. We'll handle DSP
@@ -2224,7 +2403,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
bufs[2].ptr = dst->data;
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
bufs[2].size = ggml_nbytes(dst);
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
// Primary DSP session from the src0 (normally weight) tensor
auto sess = src0_buf->sess;
@@ -2252,27 +2431,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
}
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
// Bump pending flag (cleared in the callback once we get the responce)
sess->op_pending++; // atomic inc
int err = dspqueue_write(sess->queue,
0, // flags - the framework will autoset this
3, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000 // Timeout
);
if (err != 0) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
}
}
if (opt_opsync) {
while (sess->op_pending) {
;
}
sess->enqueue(req, bufs, 3, opt_opsync);
}
t2 = ggml_time_us();
@@ -2328,7 +2487,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
bufs[0].ptr = src0->data;
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
bufs[0].size = ggml_nbytes(src0);
bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF;
bufs[0].flags = 0;
// Second buffer Input Activations. This is a buffer that the CPU
// writes and the DSP reads, so we'll need to flush CPU caches and
@@ -2338,8 +2497,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
bufs[1].ptr = src1->data;
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
bufs[1].size = ggml_nbytes(src1);
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
// Third buffer expert IDs. This is a buffer that the CPU
@@ -2350,8 +2508,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
bufs[2].ptr = src2->data;
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
bufs[2].size = ggml_nbytes(src2);
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
// Forth buffer Output Activations. We'll handle DSP
@@ -2362,7 +2519,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
bufs[3].ptr = dst->data;
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
bufs[3].size = ggml_nbytes(dst);
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
// Primary DSP session from the src0 (normally weight) tensor
auto sess = src0_buf->sess;
@@ -2391,27 +2548,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
}
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
// Bump pending flag (cleared in the callback once we get the responce)
sess->op_pending++; // atomic inc
int err = dspqueue_write(sess->queue,
0, // flags - the framework will autoset this
4, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000 // Timeout
);
if (err != 0) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
}
}
if (opt_opsync) {
while (sess->op_pending) {
;
}
sess->enqueue(req, bufs, 4, opt_opsync);
}
t2 = ggml_time_us();
@@ -2484,8 +2621,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
bufs[0].ptr = src0->data;
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
bufs[0].size = ggml_nbytes(src0);
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
// Second buffer = Second Operand of Binary op
@@ -2497,8 +2633,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
bufs[1].ptr = src1->data;
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
bufs[1].size = ggml_nbytes(src1);
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
// Third buffer = Output Activations. We'll handle DSP
@@ -2509,7 +2644,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
bufs[2].ptr = dst->data;
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
bufs[2].size = ggml_nbytes(dst);
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
// Primary DSP session from the src0 tensor
ggml_hexagon_session * sess = src0_buf->sess;
@@ -2537,26 +2672,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
}
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
// Bump pending flag (cleared in the callback once we get the responce)
sess->op_pending++; // atomic inc
int err = dspqueue_write(sess->queue,
0, // flags - the framework will autoset this
3, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000); // Timeout
if (0 != err) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
}
}
if (opt_opsync) {
while (sess->op_pending) {
;
}
sess->enqueue(req, bufs, 3, opt_opsync);
}
t2 = ggml_time_us();
@@ -2621,8 +2737,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
bufs[0].ptr = src0->data;
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
bufs[0].size = ggml_nbytes(src0);
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
// Second buffer = experts bias
@@ -2630,8 +2745,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
bufs[1].ptr = src1->data;
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
bufs[1].size = ggml_nbytes(src1);
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
// Third buffer = activated experts
@@ -2639,8 +2753,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
bufs[2].ptr = src2->data;
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
bufs[2].size = ggml_nbytes(src2);
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
// Forth buffer = output activations
@@ -2648,7 +2761,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
bufs[3].ptr = dst->data;
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
bufs[3].size = ggml_nbytes(dst);
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
// Primary DSP session from the src0 tensor
ggml_hexagon_session * sess = src0_buf->sess;
@@ -2678,26 +2791,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
}
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
// Bump pending flag (cleared in the callback once we get the responce)
sess->op_pending++; // atomic inc
int err = dspqueue_write(sess->queue,
0, // flags - the framework will autoset this
4, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000); // Timeout
if (0 != err) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
}
}
if (opt_opsync) {
while (sess->op_pending) {
;
}
sess->enqueue(req, bufs, 4, opt_opsync);
}
t2 = ggml_time_us();
@@ -2795,8 +2889,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = src0->data;
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
bufs[n_bufs].size = ggml_nbytes(src0);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
++n_bufs;
@@ -2811,8 +2904,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = src1->data;
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
bufs[n_bufs].size = ggml_nbytes(src1);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
++n_bufs;
}
@@ -2827,7 +2919,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = dst->data;
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
bufs[n_bufs].size = ggml_nbytes(dst);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
++n_bufs;
// Primary DSP session from the src0 tensor
@@ -2860,26 +2952,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
}
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
// Bump pending flag (cleared in the callback once we get the responce)
sess->op_pending++; // atomic inc
int err = dspqueue_write(sess->queue,
0, // flags - the framework will autoset this
n_bufs, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000); // Timeout
if (0 != err) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
}
}
if (opt_opsync) {
while (sess->op_pending) {
;
}
sess->enqueue(req, bufs, n_bufs, opt_opsync);
}
t2 = ggml_time_us();
@@ -2953,8 +3026,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = src0->data;
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
bufs[n_bufs].size = ggml_nbytes(src0);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
++n_bufs;
@@ -2968,8 +3040,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = src1->data;
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
bufs[n_bufs].size = ggml_nbytes(src1);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
++n_bufs;
@@ -2984,8 +3055,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = src2->data;
bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base;
bufs[n_bufs].size = ggml_nbytes(src2);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
++n_bufs;
}
@@ -3000,7 +3070,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
bufs[n_bufs].ptr = dst->data;
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
bufs[n_bufs].size = ggml_nbytes(dst);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
++n_bufs;
// Primary DSP session from the src0 tensor
@@ -3033,26 +3103,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
}
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
// Bump pending flag (cleared in the callback once we get the responce)
sess->op_pending++; // atomic inc
int err = dspqueue_write(sess->queue,
0, // flags - the framework will autoset this
n_bufs, // number of buffers
bufs, // buffer references
sizeof(req),
(const uint8_t *) &req, // Message
1000000); // Timeout
if (0 != err) {
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
}
}
if (opt_opsync) {
while (sess->op_pending) {
;
}
sess->enqueue(req, bufs, n_bufs, opt_opsync);
}
t2 = ggml_time_us();
@@ -3197,9 +3248,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
}
// Wait until all pending ops complete
while (sess->op_pending) {
;
}
sess->flush();
return GGML_STATUS_SUCCESS;
}
@@ -3210,9 +3259,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
// Wait until all pending ops complete
while (sess->op_pending) {
;
}
sess->flush();
}
struct node_info {
@@ -3628,7 +3675,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
devices[i].iface = ggml_backend_hexagon_device_i;
devices[i].reg = reg;
try {
devices[i].context = new ggml_hexagon_session(i);
devices[i].context = new ggml_hexagon_session(i, &devices[i]);
} catch (std::exception const &exc) {
GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
devices[i].context = nullptr;

View File

@@ -395,28 +395,14 @@ static void proc_matmul_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
size_t n_bufs) {
// Prep response buffer structs (needed for error responses, etc)
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[2].fd = bufs[2].fd;
rsp_bufs[2].ptr = bufs[2].ptr;
rsp_bufs[2].size = bufs[2].size;
rsp_bufs[2].offset = bufs[2].offset;
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
@@ -444,41 +430,21 @@ static void proc_matmul_req(struct htp_context * ctx,
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_matmul_id_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
size_t n_bufs) {
// Prep response buffer structs (needed for error responses, etc)
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[2].fd = bufs[2].fd;
rsp_bufs[2].ptr = bufs[2].ptr;
rsp_bufs[2].size = bufs[2].size;
rsp_bufs[2].offset = bufs[2].offset;
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[3].fd = bufs[3].fd;
rsp_bufs[3].ptr = bufs[3].ptr;
rsp_bufs[3].size = bufs[3].size;
rsp_bufs[3].offset = bufs[3].offset;
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
rsp_bufs[0].fd = bufs[3].fd;
rsp_bufs[0].ptr = bufs[3].ptr;
rsp_bufs[0].size = bufs[3].size;
rsp_bufs[0].offset = bufs[3].offset;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
@@ -508,32 +474,18 @@ static void proc_matmul_id_req(struct htp_context * ctx,
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[2].fd = bufs[2].fd;
rsp_bufs[2].ptr = bufs[2].ptr;
rsp_bufs[2].offset = bufs[2].offset;
rsp_bufs[2].size = bufs[2].size;
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
@@ -561,38 +513,18 @@ static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * r
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[2].fd = bufs[2].fd;
rsp_bufs[2].ptr = bufs[2].ptr;
rsp_bufs[2].offset = bufs[2].offset;
rsp_bufs[2].size = bufs[2].size;
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[3].fd = bufs[3].fd;
rsp_bufs[3].ptr = bufs[3].ptr;
rsp_bufs[3].offset = bufs[3].offset;
rsp_bufs[3].size = bufs[3].size;
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
rsp_bufs[0].fd = bufs[3].fd;
rsp_bufs[0].ptr = bufs[3].ptr;
rsp_bufs[0].offset = bufs[3].offset;
rsp_bufs[0].size = bufs[3].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
@@ -622,26 +554,18 @@ static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * r
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
@@ -669,7 +593,7 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_activations_req(struct htp_context * ctx,
@@ -677,33 +601,16 @@ static void proc_activations_req(struct htp_context * ctx,
struct dspqueue_buffer * bufs,
uint32_t n_bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
int write_idx = 1;
if (3 == n_bufs) {
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
write_idx = 2;
}
int write_idx = (n_bufs == 3) ? 2 : 1;
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
rsp_bufs[write_idx].size = bufs[write_idx].size;
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
rsp_bufs[0].fd = bufs[write_idx].fd;
rsp_bufs[0].ptr = bufs[write_idx].ptr;
rsp_bufs[0].offset = bufs[write_idx].offset;
rsp_bufs[0].size = bufs[write_idx].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
@@ -742,7 +649,7 @@ static void proc_activations_req(struct htp_context * ctx,
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_rope_req(struct htp_context * ctx,
@@ -750,39 +657,16 @@ static void proc_rope_req(struct htp_context * ctx,
struct dspqueue_buffer * bufs,
uint32_t n_bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
memset(rsp_bufs, 0, sizeof(rsp_bufs));
rsp_bufs[0].fd = bufs[0].fd;
rsp_bufs[0].ptr = bufs[0].ptr;
rsp_bufs[0].offset = bufs[0].offset;
rsp_bufs[0].size = bufs[0].size;
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
rsp_bufs[1].fd = bufs[1].fd;
rsp_bufs[1].ptr = bufs[1].ptr;
rsp_bufs[1].offset = bufs[1].offset;
rsp_bufs[1].size = bufs[1].size;
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
int write_idx = 2;
if (4 == n_bufs) {
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
rsp_bufs[write_idx].size = bufs[write_idx].size;
rsp_bufs[write_idx].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
write_idx++;
}
int write_idx = (n_bufs == 4) ? 3 : 2;
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
rsp_bufs[write_idx].size = bufs[write_idx].size;
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
rsp_bufs[0].fd = bufs[write_idx].fd;
rsp_bufs[0].ptr = bufs[write_idx].ptr;
rsp_bufs[0].offset = bufs[write_idx].offset;
rsp_bufs[0].size = bufs[write_idx].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
@@ -819,7 +703,7 @@ static void proc_rope_req(struct htp_context * ctx,
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {

View File

@@ -29,10 +29,11 @@ if (CXX_IS_HIPCC)
endif()
else()
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if(AMDGPU_TARGETS AND NOT GPU_TARGETS)
set(GPU_TARGETS ${AMDGPU_TARGETS})
endif()
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
endif()
cmake_minimum_required(VERSION 3.21)
enable_language(HIP)

View File

@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
#endif
#ifdef __cplusplus
#include <array>
#include <initializer_list>
#include <vector>
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
// Return true if the edges in the graph match expectations.
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<std::array<int, 3>> edges) {
for (const auto & edge : edges) {
int dst_node = edge[0];
int src_idx = edge[1];
int src_node = edge[2];
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
return false;
}
}
return true;
}
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

View File

@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_ne02=%d", base, ne02);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_neox) {
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
} else if (is_mrope && !is_vision) {
} else if ((is_mrope || is_imrope) && !is_vision) {
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
} else if (is_vision) {
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}

View File

@@ -76,6 +76,7 @@
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
#define FC_ROPE 800
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8

View File

@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
@@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi(
const int sector = ic % sect_dims;
float theta_base;
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
theta_base = (float) pos[i2 + args.ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
theta_base = (float) pos[i2 + args.ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + args.ne02 * 3];
}
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
}
// end of mrope

View File

@@ -6156,8 +6156,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = (float)(ne0 - 1) / (ne00 - 1);
sf1 = (float)(ne1 - 1) / (ne01 - 1);
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
pixel_offset = 0.0f;
}

View File

@@ -79,8 +79,8 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (loadc_a + l < ne01) {
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
if (ir*BM + loadc_a + l < ne01) {
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
}
for (int l = 0; l < BN; l += loadstride_b) {
if (loadc_b + l < ne11) {
if (ic*BN + loadc_b + l < ne11) {
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;

View File

@@ -79,7 +79,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (loadc_a + l < ne01) {
if (ir*BM + loadc_a + l < ne01) {
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
}
for (int l = 0; l < BN; l += loadstride_b) {
if (loadc_b + l < ne11) {
if (ic*BN + loadc_b + l < ne11) {
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;

View File

@@ -78,7 +78,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (loadc_a + l < ne01) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 8;
int iqs = idx % 8;
@@ -101,7 +101,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
}
for (int l = 0; l < BN; l += loadstride_b) {
if (loadc_b + l < ne11) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;

View File

@@ -32,8 +32,10 @@
#include "pad.hpp"
#include "quantize.hpp"
#include "quants.hpp"
#include "roll.hpp"
#include "rope.hpp"
#include "set_rows.hpp"
#include "ssm_conv.hpp"
#include "softmax.hpp"
#include "tsembd.hpp"
#include "wkv.hpp"

View File

@@ -42,13 +42,16 @@
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/element_wise.hpp"
#include "ggml-sycl/norm.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml.h"
static bool g_sycl_loaded = false;
@@ -2615,6 +2618,10 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_repeat_back(ctx, dst);
}
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -2631,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
ggml_sycl_op_rms_norm(ctx, dst);
}
static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_rms_norm_back(ctx, dst);
}
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_l2_norm(ctx, dst);
@@ -3679,6 +3691,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_REPEAT:
ggml_sycl_repeat(ctx, dst);
break;
case GGML_OP_REPEAT_BACK:
ggml_sycl_repeat_back(ctx, dst);
break;
case GGML_OP_GET_ROWS:
ggml_sycl_get_rows(ctx, dst);
break;
@@ -3818,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_LEAKY_RELU:
ggml_sycl_leaky_relu(ctx, dst);
break;
case GGML_OP_RMS_NORM_BACK:
ggml_sycl_rms_norm_back(ctx, dst);
break;
case GGML_OP_RMS_NORM:
ggml_sycl_rms_norm(ctx, dst);
break;
@@ -3913,6 +3931,11 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
@@ -4516,6 +4539,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
}
case GGML_OP_REPEAT_BACK:
{
ggml_type src0_type = op->src[0]->type;
return src0_type == GGML_TYPE_F32;
}
case GGML_OP_DUP:
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
@@ -4552,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
case GGML_OP_RMS_NORM_BACK:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:
@@ -4586,6 +4616,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_SSM_CONV:
return op->type == GGML_TYPE_F32 &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_ROLL:
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
default:

View File

@@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float eps = 1e-5f;
std::memcpy(&eps, dst->op_params, sizeof(float));
if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;
const float * g_base = static_cast<const float *>(dst->src[0]->data); // dz
const float * x_base = static_cast<const float *>(dst->src[1]->data); // x
float * dx_base = static_cast< float *>(dst->data);
const int64_t D = dst->ne[0];
const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
const int64_t N = ggml_nrows(dst);
if (D == 0 || N == 0) return;
const ggml_tensor *G = dst->src[0];
const ggml_tensor *X = dst->src[1];
const int ts = (int) ggml_type_size(X->type);
GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
dpct::queue_ptr q = ctx.stream();
// work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
int wg_cap = 256;
if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
// FP32 path: per-thread compensated accumulation + hierarchical reduction
q->submit([&](sycl::handler &cgh) {
const int nwarps_loc = std::max(1, WG / WARP_SIZE);
// store one partial value per warp (xx and xg) for cross-warp reduction
auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
sycl::range<3>(1, 1, WG)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const int row = item_ct1.get_group(2);
const int tid = item_ct1.get_local_id(2);
const int64_t i1 = row % n1;
const int64_t i2 = (row / n1) % n2;
const int64_t i3 = row / (n1 * n2);
const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
// per-thread accumulation (compensated by default)
float sum_xx = 0.f, sum_xg = 0.f;
#ifndef GGML_SYCL_RMS_BACK_FAST
float c_xx = 0.f, c_xg = 0.f;
#endif
for (int64_t col = tid; col < D; col += WG) {
const float xv = x_row[col];
const float gv = g_row[col];
#ifdef GGML_SYCL_RMS_BACK_FAST
sum_xx += xv * xv;
sum_xg += xv * gv;
#else
float y1 = xv * xv - c_xx;
float t1 = sum_xx + y1;
c_xx = (t1 - sum_xx) - y1;
sum_xx = t1;
float y2 = xv * gv - c_xg;
float t2 = sum_xg + y2;
c_xg = (t2 - sum_xg) - y2;
sum_xg = t2;
#endif
}
// warp-level reduction
sycl::float2 xx = sycl::float2(sum_xx,
#ifndef GGML_SYCL_RMS_BACK_FAST
c_xx
#else
0.f
#endif
);
sycl::float2 xg = sycl::float2(sum_xg,
#ifndef GGML_SYCL_RMS_BACK_FAST
c_xg
#else
0.f
#endif
);
xx = warp_reduce_sum(xx, item_ct1);
xg = warp_reduce_sum(xg, item_ct1);
// cross-warp reduction using local memory (single barrier)
const auto sub_group = item_ct1.get_sub_group();
const auto sg_id = sub_group.get_group_linear_id();
const auto wi_in_sg = sub_group.get_local_linear_id();
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
sycl::float2 xx_total = xx;
sycl::float2 xg_total = xg;
if (nwarps > 1) {
if (wi_in_sg == 0) {
l_xx[sg_id] = xx;
l_xg[sg_id] = xg;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
if (sg_id == 0) {
const unsigned wi_u = wi_in_sg;
sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
xx_total = warp_reduce_sum(xx_first, item_ct1);
xg_total = warp_reduce_sum(xg_first, item_ct1);
} else {
// other subgroups keep their local totals; they'll be ignored
xx_total = xx;
xg_total = xg;
}
// ensure all threads see the first-subgroup result via broadcast below
}
// compute inv_r and coeff once per row and broadcast to the whole work-group
float inv_r = 0.f;
float coeff = 0.f;
if (tid == 0) {
const float sum_xx_f = xx_total.x() + xx_total.y();
const float sum_xdz_f = xg_total.x() + xg_total.y();
const float mean_eps = sum_xx_f / (float) D + eps;
const float sum_eps = sum_xx_f + eps * (float) D;
inv_r = sycl::rsqrt(mean_eps);
coeff = -sum_xdz_f / sum_eps;
}
inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
for (int64_t col = tid; col < D; col += WG) {
d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
}
});
});
}
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);

View File

@@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);

View File

@@ -0,0 +1,56 @@
#include "repeat_back.hpp"
#include "common.hpp"
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const float * src0_dd = (const float *) dst->src[0]->data;
float * dst_dd = (float *) dst->data;
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
ne03 = dst->src[0]->ne[3];
const int nr0 = (int) (ne00 / ne0);
const int nr1 = (int) (ne01 / ne1);
const int nr2 = (int) (ne02 / ne2);
const int nr3 = (int) (ne03 / ne3);
const size_t total = ne0 * ne1 * ne2 * ne3;
const int BLOCK_SIZE = 256;
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
queue_ptr stream = ctx.stream();
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
const size_t i = item_ct1.get_global_linear_id();
if (i >= total) {
return;
}
const int i0 = i % ne0;
const int i1 = (i / ne0) % ne1;
const int i2 = (i / (ne0 * ne1)) % ne2;
const int i3 = i / (ne0 * ne1 * ne2);
float acc = 0.0f;
for (int j3 = 0; j3 < nr3; ++j3) {
for (int j2 = 0; j2 < nr2; ++j2) {
for (int j1 = 0; j1 < nr1; ++j1) {
for (int j0 = 0; j0 < nr0; ++j0) {
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
}
}
}
}
dst_dd[i] = acc;
});
}

View File

@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_REPEAT_BACK_HPP
#define GGML_SYCL_REPEAT_BACK_HPP
#include "common.hpp"
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_REPEAT_BACK_HPP

122
ggml/src/ggml-sycl/roll.cpp Normal file
View File

@@ -0,0 +1,122 @@
#include "roll.hpp"
#include "common.hpp"
using namespace sycl;
static inline int wrap_add(int i, int shift, int n) {
int s = i + shift;
return (s >= n) ? (s - n) : s;
}
static void kernel_roll_fused_i0_i1(
queue &q,
const float *src_d,
float *dst_d,
int ne0, int ne1, int ne2, int ne3,
int sh0, int sh1, int sh2, int sh3)
{
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
const int stride1 = ne0;
const int stride2 = ne0 * ne1;
const int stride3 = ne0 * ne1 * ne2;
const int shNe0 = (ne0 - sh0) % ne0;
const int shNe1 = (ne1 - sh1) % ne1;
const int shNe2 = (ne2 - sh2) % ne2;
const int shNe3 = (ne3 - sh3) % ne3;
const size_t g0 = (size_t) ne3;
const size_t g1 = (size_t) ne2;
const size_t g2 = (size_t) (ne1 * ne0);
const range<3> global{ g0, g1, g2 };
q.submit([&](handler &h) {
h.parallel_for(global, [=](id<3> idx) {
const int i3 = (int) idx[0];
const int i2 = (int) idx[1];
const int fused = (int) idx[2];
const int i1 = fused / ne0;
const int i0 = fused - i1 * ne0; // fused % ne0
const int idx_dst = i0
+ i1 * stride1
+ i2 * stride2
+ i3 * stride3;
const int s0 = wrap_add(i0, shNe0, ne0);
const int s1 = wrap_add(i1, shNe1, ne1);
const int s2 = wrap_add(i2, shNe2, ne2);
const int s3 = wrap_add(i3, shNe3, ne3);
const int idx_src = s0
+ s1 * stride1
+ s2 * stride2
+ s3 * stride3;
dst_d[idx_dst] = src_d[idx_src];
});
});
}
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const ggml_tensor *src = dst->src[0];
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
const int ne0 = (int) dst->ne[0];
const int ne1 = (int) dst->ne[1];
const int ne2 = (int) dst->ne[2];
const int ne3 = (int) dst->ne[3];
const int32_t *params = (const int32_t *) dst->op_params;
int shift0 = params[0];
int shift1 = params[1];
int shift2 = params[2];
int shift3 = params[3];
if ((shift0 | shift1 | shift2 | shift3) == 0) {
const size_t nb = ggml_nbytes(src);
queue *q = ctx.stream();
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
return;
}
auto norm = [](int sh, int n) -> int {
if (n <= 0) return 0;
sh %= n;
if (sh < 0) sh += n;
return sh;
};
shift0 = norm(shift0, ne0);
shift1 = norm(shift1, ne1);
shift2 = norm(shift2, ne2);
shift3 = norm(shift3, ne3);
try {
queue *q = ctx.stream();
const float *src_d = (const float *) src->data;
float *dst_d = (float *) dst->data;
GGML_ASSERT(src_d && dst_d);
kernel_roll_fused_i0_i1(
*q, src_d, dst_d,
ne0, ne1, ne2, ne3,
shift0, shift1, shift2, shift3
);
} catch (const std::exception &e) {
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
throw;
}
}

View File

@@ -0,0 +1,20 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_ROLL_HPP
#define GGML_SYCL_ROLL_HPP
#include "common.hpp"
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
#endif // GGML_SYCL_ROLL_HPP

View File

@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const sycl::nd_item<3> & item_ct1) {
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, queue_ptr stream) {
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
}
}
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, main_stream);
freq_factors, sections, is_imrope, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
main_stream);
is_imrope, main_stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}

View File

@@ -0,0 +1,127 @@
#include "ssm_conv.hpp"
#include "common.hpp"
#include <cstdio>
using namespace sycl;
static void kernel_ssm_conv(
queue &q,
const float *src_data,
const float *weights,
float *dst_data,
int d_conv,
int d_inner,
int n_t,
int n_s,
int ncs __attribute__((unused)),
int src_stride_inner,
int src_stride_seq,
int dst_stride_token,
int dst_stride_seq
) {
const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);
const size_t work_group_size = 256;
const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
const range<1> global_range(num_work_groups * work_group_size);
const range<1> local_range(work_group_size);
q.submit([&](handler &h) {
h.parallel_for(
nd_range<1>(global_range, local_range),
[=](nd_item<1> item) {
const size_t idx = item.get_global_id(0);
if (idx >= total_work) {
return;
}
const int channel = static_cast<int>(idx % d_inner);
const int token = static_cast<int>((idx / d_inner) % n_t);
const int seq = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));
const float *s = src_data
+ static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
+ static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)
+ static_cast<size_t>(token);
const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);
float sumf = 0.0f;
for (int i0 = 0; i0 < d_conv; ++i0) {
sumf += s[i0] * c[i0];
}
const size_t dst_idx =
static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +
static_cast<size_t>(channel);
dst_data[dst_idx] = sumf;
}
);
});
}
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int d_conv = src1->ne[0];
const int ncs = src0->ne[0];
const int d_inner = src0->ne[1];
const int n_t = dst->ne[1];
const int n_s = dst->ne[2];
GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(src0->ne[1] == d_inner);
GGML_ASSERT(src1->ne[1] == d_inner);
GGML_ASSERT(dst->ne[0] == d_inner);
GGML_ASSERT(dst->ne[1] == n_t);
GGML_ASSERT(dst->ne[2] == n_s);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));
const int src_stride_inner = ncs;
const int src_stride_seq = ncs * d_inner;
const int dst_stride_token = d_inner;
const int dst_stride_seq = d_inner * n_t;
try {
queue *q = ctx.stream();
const float *src_data = static_cast<const float *>(src0->data);
const float *weights = static_cast<const float *>(src1->data);
float *dst_data = static_cast<float *>(dst->data);
GGML_ASSERT(src_data && weights && dst_data);
kernel_ssm_conv(
*q,
src_data,
weights,
dst_data,
d_conv,
d_inner,
n_t,
n_s,
ncs,
src_stride_inner,
src_stride_seq,
dst_stride_token,
dst_stride_seq
);
} catch (const std::exception &e) {
std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
throw;
}
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include "common.hpp"
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

File diff suppressed because it is too large Load Diff

View File

@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint nrows;
uint order;
} p;
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
dst_row[idx1] = tmp;
}
void argsort(bool needs_bounds_check) {
void argsort(bool needs_bounds_check, const uint row) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
void main() {
if (p.ncols == BLOCK_SIZE) {
argsort(false);
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(false, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
} else {
argsort(true);
uint row = gl_WorkGroupID.y;
while (row < p.nrows) {
argsort(true, row);
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
}
}

View File

@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
const vec2 dm = vec2(data_a[a_offset + ib].dm);
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);

View File

@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
#endif

View File

@@ -26,7 +26,7 @@ void main() {
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}

View File

@@ -24,8 +24,8 @@ void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));

View File

@@ -20,8 +20,8 @@ void main() {
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;

View File

@@ -19,8 +19,8 @@ void main() {
const uint ir = tid % 16;
const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;

View File

@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
}
}
}

View File

@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}

View File

@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}

View File

@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {

View File

@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
const vec2 d = vec2(data_a[ib].d);
const vec2 dm = vec2(data_a[ib].dm);
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
const float d = e8m0_to_fp32(data_a[ib].e);
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);

View File

@@ -0,0 +1,70 @@
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID

View File

@@ -10,10 +10,9 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -24,7 +23,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -76,40 +78,31 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif
#define MMQ_SHMEM
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif
#endif
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 16
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
@@ -139,26 +132,12 @@ void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +151,27 @@ void main() {
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
@@ -209,159 +198,70 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif
FLOAT_TYPE_VEC2 cache_b_ds[TN];
for (uint block = start_k; block < end_k; block += BK) {
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;
if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
const uint iqs = loadr_b;
#endif
const uint buf_ib = loadc_b + l;
if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[buf_ib];
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
const uint iqs = loadr_b;
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
pos_a_ib += BK_STEP;
pos_b_ib += BK_STEP;
#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
}
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
}
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
block_b_to_registers(ib);
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] += mmq_dot_product(cache_a_idx);
}
}
}
}
}
#endif
barrier();
}
@@ -373,54 +273,6 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -431,19 +283,21 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}

View File

@@ -6,41 +6,89 @@
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
#else // DATA_A_Q4_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
}
#else // DATA_A_Q5_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].qh = buf_a[buf_ib].qh;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]));
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
const int32_t qs_b = cache_b.qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, qs_b);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
if (iqs == 0) {
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d = buf_a[buf_ib].d;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
// Repack 4x4 quants into one int
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t sum_d = 0;
int32_t sum_m = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
const uint iqs_k = (ib % 8) * 8 + hm_idx;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// Repack 2x4 quants into one int
// Add the 3rd bit instead of subtracting it to allow packing the quants
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
// Subtract 4 from the quants to correct the 3rd bit offset
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
// Repack 2x4 quants into one int
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs * QUANT_R_MMQ;
const uint qh_shift = iqs_k / 8;
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
#if defined(DATA_A_Q4_K)
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
#else // defined(DATA_A_Q5_K)
const int32_t qs_a = cache_a[ib_a].qs[iqs];
#endif
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif

View File

@@ -0,0 +1,78 @@
#if defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q4_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q5_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE d;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
struct block_a_cache {
uint32_t qs[2];
u8vec2 scales;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q3_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 d_scales;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q6_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 d_scales;
};
#endif
struct block_b_cache
{
int32_t qs[8];
FLOAT_TYPE_VEC2 ds;
};

View File

@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
layout (push_constant) uniform parameter {
uint ncols;
@@ -26,7 +27,9 @@ layout (push_constant) uniform parameter {
uint s1;
uint s2;
int sections[4];
uint is_imrope;
uint is_back;
uint set_rows_stride;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {

View File

@@ -32,17 +32,29 @@ void main() {
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
if (p.is_imrope != 0) {
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
} else {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
} else {
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View File

@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
// Fusion optimization: ROPE + VIEW + SET_ROWS..
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0/2;
idst += data_i[channel_x].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
return;
}

View File

@@ -16,12 +16,19 @@ void main() {
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0;
uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
// Fusion optimization: ROPE + VIEW + SET_ROWS..
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
if (p.set_rows_stride != 0) {
idst = row_x*ne0 + i0;
idst += data_i[channel_x].x * p.set_rows_stride;
}
if (i0 >= p.n_dims) {
data_d[idst + 0] = data_a[ix + 0];
data_d[idst + 1] = data_a[ix + 1];
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
return;
}

View File

@@ -1,6 +1,9 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
#include "types.glsl"
@@ -84,35 +87,47 @@ void main() {
}
barrier();
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % (w >> 1)) +
(D_STATE * (tid / (w >> 1))) +
j * D_STATE * (D_STATE / (w >> 1));
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + (w >> 1)];
[[unroll]]
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + w];
}
}
barrier();
}
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
const uint max_idx = SUBGROUP_SIZE - 1 +
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
uint lane = tid % SUBGROUP_SIZE;
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
if (idx < SPLIT_H * D_STATE ||
max_idx < SPLIT_H * D_STATE) {
float sc;
#if USE_SUBGROUP_ADD
sc = stateC[idx];
sc = subgroupAdd(sc);
#else
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
barrier();
}
if (tid % SUBGROUP_SIZE == 0) {
sc = stateC[idx];
}
#endif
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = stateC[idx];
if (tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = sc;
}
}
}

View File

@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
float clamp_min;
float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
const float INFINITY = 1.0 / 0.0;
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
float max_val = -INFINITY;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
max_val = max(max_val, vals[i]);
}
}
max_val = subgroupMax(max_val);
float sum = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
const float val = exp(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}
sum = subgroupAdd(sum);
const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
vals[i] *= inv_sum;
}
}
}
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
@@ -35,43 +84,16 @@ void main() {
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
float logits_r[experts_per_thread];
const float INFINITY = 1.0 / 0.0;
float wt[experts_per_thread];
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
const uint expert = i + gl_LocalInvocationID.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
}
float max_val = logits_r[0];
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}
max_val = subgroupMax(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
}
tmp = subgroupAdd(tmp);
const float inv_sum = 1.0f / tmp;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] = 0.f;
}
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}
if (late_softmax) {
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
}
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;

View File

@@ -66,6 +66,7 @@ struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
f16vec2 d;
f16vec2 dm;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
f16vec2 d;
f16vec2 dm;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
f16vec2 d;
f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
#define SCALES_PER_32 2
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
struct block_q4_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};
struct block_q4_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};
struct block_q4_K_packed32
{
f16vec2 d;
f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
struct block_q5_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[12];
uint8_t qh[QUANT_K_Q5_K/8];
uint8_t qs[QUANT_K_Q5_K/2];
@@ -324,12 +333,20 @@ struct block_q5_K
struct block_q5_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[12/2];
uint16_t qh[QUANT_K_Q5_K/8/2];
uint16_t qs[QUANT_K_Q5_K/2/2];
};
struct block_q5_K_packed32
{
f16vec2 dm;
uint32_t scales[12/4];
uint32_t qh[QUANT_K_Q5_K/8/4];
uint32_t qs[QUANT_K_Q5_K/2/4];
};
struct block_q5_K_packed128
{
uvec4 q5k[11];
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#define A_TYPE_PACKED32 block_q5_K_packed32
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
{
uint16_t ql[QUANT_K_Q6_K/2/2];
uint16_t qh[QUANT_K_Q6_K/4/2];
int8_t scales[QUANT_K_Q6_K/16];
int16_t scales[QUANT_K_Q6_K/16/2];
float16_t d;
};
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#define DATA_A_QUANT_K
#endif
// IQuants
@@ -1363,18 +1383,11 @@ struct block_mxfp4
uint8_t qs[QUANT_K_MXFP4/2];
};
//struct block_mxfp4_packed16
//{
// uint8_t e;
// uint16_t qs[QUANT_K_MXFP4/2/2];
//};
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
#endif
#if defined(DATA_A_MXFP4)
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
const int8_t kvalues_mxfp4_const[16] = {
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};
shared FLOAT_TYPE kvalues_mxfp4[16];
shared int8_t kvalues_mxfp4[16];
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)

View File

@@ -7,6 +7,7 @@ layout (push_constant) uniform parameter
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
float pixel_offset;
} p;
#include "types.glsl"
@@ -19,7 +20,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
#define NEAREST 0
#define BILINEAR 1
#define ALIGN_CORNERS (1 << 8)
layout (constant_id = 0) const uint scale_mode = 0;
@@ -52,7 +52,7 @@ float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;
const vec2 c0f = floor(c);
const vec2 d = c - c0f;
const ivec2 c0 = max(ivec2(c0f), 0);
@@ -61,16 +61,6 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
return fetch_bilinear(c0, c1, d, i12, i13);
}
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
const vec2 c0f = floor(c);
const vec2 d = c - c0f;
const ivec2 c0 = ivec2(c0f);
const ivec2 c1 = c0 + 1;
return fetch_bilinear(c0, c1, d, i12, i13);
}
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
@@ -91,9 +81,6 @@ void main() {
case BILINEAR:
result = interpolate_bilinear(i10, i11, i12, i13);
break;
case BILINEAR | ALIGN_CORNERS:
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
break;
}
data_d[p.d_offset + idx] = D_TYPE(result);

View File

@@ -317,7 +317,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O";
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos || name.find("rope") != std::string::npos) ? "" : "-O";
#ifdef _WIN32
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
@@ -566,7 +567,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -574,7 +576,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
@@ -841,10 +843,14 @@ void process_shaders() {
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -916,7 +922,8 @@ void process_shaders() {
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});

View File

@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let is_neox = bool(params.mode & 2);
let is_mrope = bool(params.mode & 8);
let is_imrope = params.mode == 40;
let is_vision = params.mode == 24;
var i = gid.x * 2; // start index for this thread
@@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let sec_w = params.sections1 + params.sections0;
let sec_e = params.sections2 + sec_w;
let sector = (i0 / 2) % sect_dims;
if (sector >= params.sections0 && sector < sec_w) {
theta_base_mult = 1;
if (is_vision) {
theta_scale_pwr = sector - params.sections0;
}
} else if (sector >= sec_w && sector < sec_e) {
theta_base_mult = 2;
if (is_vision) {
theta_scale_pwr = sector - sec_w;
}
} else if (sector >= sec_e) {
if (is_vision) {
theta_scale_pwr = sector - sec_e;
theta_scale_pwr = (i0 / 2) % sec_e;
}
theta_base_mult = 3;
} else if (is_vision) {
theta_scale_pwr = sector;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * params.sections1) {
theta_base_mult = 1;
} else if (sector % 3 == 2 && sector < 3 * params.sections2) {
theta_base_mult = 2;
} else if (sector % 3 == 0 && sector < 3 * params.sections0) {
theta_base_mult = 0;
} else {
theta_base_mult = 3;
}
} else {
if (sector >= params.sections0 && sector < sec_w) {
theta_base_mult = 1;
if (is_vision) {
theta_scale_pwr = sector - params.sections0;
}
} else if (sector >= sec_w && sector < sec_e) {
theta_base_mult = 2;
if (is_vision) {
theta_scale_pwr = sector - sec_w;
}
} else if (sector >= sec_e) {
if (is_vision) {
theta_scale_pwr = sector - sec_e;
theta_scale_pwr = (i0 / 2) % sec_e;
}
theta_base_mult = 3;
} else if (is_vision) {
theta_scale_pwr = sector;
}
}
}
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));

View File

@@ -111,6 +111,7 @@ class Keys:
EXPERTS_PER_GROUP = "{arch}.experts_per_group"
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -277,6 +278,7 @@ class Keys:
USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu"
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
class Attention:
HEAD_COUNT = "clip.vision.attention.head_count"
@@ -350,6 +352,8 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
@@ -420,6 +424,7 @@ class MODEL_ARCH(IntEnum):
SEED_OSS = auto()
GROVEMOE = auto()
APERTUS = auto()
COGVLM = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -430,6 +435,8 @@ class VISION_PROJECTOR_TYPE(IntEnum):
GLM_EDGE = auto()
MERGER = auto()
GEMMA3 = auto()
QWEN3VL = auto()
COGVLM = auto()
class MODEL_TENSOR(IntEnum):
@@ -600,6 +607,11 @@ class MODEL_TENSOR(IntEnum):
SHORTCONV_CONV = auto()
SHORTCONV_INPROJ = auto()
SHORTCONV_OUTPROJ = auto()
VISEXP_ATTN_QKV = auto()
VISEXP_ATTN_OUT = auto()
VISEXP_GATE = auto()
VISEXP_DOWN = auto()
VISEXP_UP = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -609,6 +621,7 @@ class MODEL_TENSOR(IntEnum):
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_POS = auto()
V_ENC_INPUT_NORM = auto()
V_ENC_ATTN_QKV = auto()
V_ENC_ATTN_Q = auto()
V_ENC_ATTN_Q_NORM = auto()
V_ENC_ATTN_K = auto()
@@ -640,6 +653,15 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
V_MM_PATCH_MERGER = auto() # mistral small 3.1
V_DS_NORM = auto() # qwen3vl
V_DS_FC1 = auto() # qwen3vl
V_DS_FC2 = auto() # qwen3vl
V_MM_POST_FC_NORM = auto() # cogvlm
V_MM_UP = auto() # cogvlm
V_MM_DOWN = auto() # cogvlm
V_MM_GATE = auto() # cogvlm
V_TOK_BOI = auto() # cogvlm
V_TOK_EOI = auto() # cogvlm
# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_CONV1D = auto()
@@ -695,6 +717,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
@@ -766,6 +790,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.SEED_OSS: "seed_oss",
MODEL_ARCH.GROVEMOE: "grovemoe",
MODEL_ARCH.APERTUS: "apertus",
MODEL_ARCH.COGVLM: "cogvlm",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -946,6 +971,11 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
MODEL_TENSOR.VISEXP_ATTN_QKV: "blk.{bid}.vis_attn_qkv",
MODEL_TENSOR.VISEXP_ATTN_OUT: "blk.{bid}.vis_attn_output",
MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate",
MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down",
MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -954,6 +984,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm",
MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
@@ -986,6 +1017,15 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
MODEL_TENSOR.V_DS_NORM: "v.deepstack.{bid}.norm",
MODEL_TENSOR.V_DS_FC1: "v.deepstack.{bid}.fc1",
MODEL_TENSOR.V_DS_FC2: "v.deepstack.{bid}.fc2",
MODEL_TENSOR.V_MM_POST_FC_NORM: "mm.post_fc_norm", # cogvlm
MODEL_TENSOR.V_MM_UP: "mm.up",
MODEL_TENSOR.V_MM_DOWN: "mm.down",
MODEL_TENSOR.V_MM_GATE: "mm.gate",
MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
@@ -1023,6 +1063,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_ATTN_QKV,
MODEL_TENSOR.V_ENC_ATTN_Q,
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
MODEL_TENSOR.V_ENC_ATTN_K,
@@ -1054,6 +1095,15 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
MODEL_TENSOR.V_MM_PATCH_MERGER,
MODEL_TENSOR.V_DS_NORM,
MODEL_TENSOR.V_DS_FC1,
MODEL_TENSOR.V_DS_FC2,
MODEL_TENSOR.V_MM_POST_FC_NORM,
MODEL_TENSOR.V_MM_UP,
MODEL_TENSOR.V_MM_DOWN,
MODEL_TENSOR.V_MM_GATE,
MODEL_TENSOR.V_TOK_BOI,
MODEL_TENSOR.V_TOK_EOI,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_CONV1D,
@@ -1495,6 +1545,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.QWEN3VLMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -2837,6 +2921,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_CHEXP,
MODEL_TENSOR.FFN_UP_CHEXP,
],
MODEL_ARCH.COGVLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.VISEXP_ATTN_QKV,
MODEL_TENSOR.VISEXP_ATTN_OUT,
MODEL_TENSOR.VISEXP_GATE,
MODEL_TENSOR.VISEXP_UP,
MODEL_TENSOR.VISEXP_DOWN,
],
# TODO
}
@@ -3055,6 +3156,7 @@ class VisionProjectorType:
LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
QWEN3VL = "qwen3vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio
@@ -3062,6 +3164,8 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
# Items here are (block size, type size)

View File

@@ -860,6 +860,9 @@ class GGUFWriter:
def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
def add_num_deepstack_layers(self, count: int) -> None:
self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)
def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
@@ -1071,6 +1074,9 @@ class GGUFWriter:
def add_vision_n_wa_pattern(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
# audio models
def add_audio_projection_dim(self, value: int) -> None:

View File

@@ -104,6 +104,7 @@ class TensorNameMap:
"backbone.final_layer_norm", # wavtokenizer
"model.norm", # llama4
"model.transformer.ln_f", # llada
"model.norm", # cogvlm
),
# Rope frequencies
@@ -162,6 +163,7 @@ class TensorNameMap:
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
"rwkv.blocks.{bid}.ln2", # rwkv6
"model.layers.{bid}.ln2", # rwkv7
"model.layers.{bid}.post_attention_layernorm", # cogvlm
),
# Attention query-key-value
@@ -184,6 +186,7 @@ class TensorNameMap:
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
"transformer_encoder.{bid}.qkv", # neobert
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
),
# Attention query
@@ -279,6 +282,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.attn_out", # llada
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
"model.layers.{bid}.self_attn.language_expert_dense", # cogvlm
),
# Attention output norm
@@ -418,6 +422,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.up_proj", # llada
"layers.{bid}.mlp.up_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
"model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm
),
MODEL_TENSOR.FFN_UP_EXP: (
@@ -450,21 +455,22 @@ class TensorNameMap:
# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
"layers.{bid}.mlp.gate_proj", # embeddinggemma
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.transformer.blocks.{bid}.ff_proj", # llada
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
"layers.{bid}.mlp.gate_proj", # embeddinggemma
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.transformer.blocks.{bid}.ff_proj", # llada
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
"model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm
),
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -522,6 +528,7 @@ class TensorNameMap:
"model.transformer.blocks.{bid}.ff_out", # llada
"layers.{bid}.mlp.down_proj", # qwen3-embedding
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
"model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -1047,6 +1054,26 @@ class TensorNameMap:
"encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
),
MODEL_TENSOR.VISEXP_UP: (
"model.layers.{bid}.mlp.vision_mlp.up_proj", # cogvlm
),
MODEL_TENSOR.VISEXP_GATE: (
"model.layers.{bid}.mlp.vision_mlp.gate_proj", # cogvlm
),
MODEL_TENSOR.VISEXP_DOWN: (
"model.layers.{bid}.mlp.vision_mlp.down_proj", # cogvlm
),
MODEL_TENSOR.VISEXP_ATTN_OUT: (
"model.layers.{bid}.self_attn.vision_expert_dense", # cogvlm
),
MODEL_TENSOR.VISEXP_ATTN_QKV: (
"model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm
),
############################################################################
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
MODEL_TENSOR.ENC_OUTPUT_NORM: (
@@ -1148,6 +1175,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ_FC: (
"model.connector.modality_projection.proj", # SmolVLM
"model.vision.linear_proj.linear_proj", # cogvlm
),
MODEL_TENSOR.V_MMPROJ_MLP: (
@@ -1164,6 +1192,7 @@ class TensorNameMap:
"vision_tower.vision_model.embeddings.class_embedding",
"model.vision_tower.embeddings.cls_token", # Intern-S1
"vision_model.class_embedding", # llama 4
"model.vision.patch_embedding.cls_embedding", # cogvlm
),
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@@ -1176,6 +1205,7 @@ class TensorNameMap:
"vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl
"model.vision.patch_embedding.proj", # cogvlm
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -1185,6 +1215,13 @@ class TensorNameMap:
"model.vision_model.embeddings.position_embedding", # SmolVLM
"vision_model.positional_embedding_vlm", # llama 4
"vision_tower.patch_embed.pos_emb", # kimi-vl
"visual.pos_embed", # qwen3vl
"model.vision.patch_embedding.position_embedding", # cogvlm
),
MODEL_TENSOR.V_ENC_ATTN_QKV: (
"visual.blocks.{bid}.attn.qkv", # qwen3vl
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1244,6 +1281,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
),
MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1257,6 +1295,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1270,6 +1309,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
),
MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1282,7 +1322,9 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1301,7 +1343,9 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
),
MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1338,6 +1382,7 @@ class TensorNameMap:
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
@@ -1397,6 +1442,42 @@ class TensorNameMap:
"patch_merger.merging_layer", # mistral
),
MODEL_TENSOR.V_DS_NORM: (
"model.visual.deepstack_merger_list.{bid}.norm", # deepstack in qwen3vl
),
MODEL_TENSOR.V_DS_FC1: (
"model.visual.deepstack_merger_list.{bid}.linear_fc1", # deepstack in qwen3vl
),
MODEL_TENSOR.V_DS_FC2: (
"model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
),
MODEL_TENSOR.V_MM_POST_FC_NORM: (
"model.vision.linear_proj.norm1", # cogvlm
),
MODEL_TENSOR.V_MM_UP: (
"model.vision.linear_proj.dense_h_to_4h", # cogvlm
),
MODEL_TENSOR.V_MM_DOWN: (
"model.vision.linear_proj.dense_4h_to_h", # cogvlm
),
MODEL_TENSOR.V_MM_GATE: (
"model.vision.linear_proj.gate_proj", # cogvlm
),
MODEL_TENSOR.V_TOK_BOI: (
"model.vision.boi", # cogvlm
),
MODEL_TENSOR.V_TOK_EOI: (
"model.vision.eoi", # cogvlm
),
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: (

View File

@@ -14,12 +14,12 @@ except ImportError:
SentencePieceProcessor = None
try:
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.utils import (
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
_filter_valid_tokenizer_files,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
SentencePieceTokenizer,
)
except ImportError:

View File

@@ -83,6 +83,7 @@ extern "C" {
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
};

View File

@@ -0,0 +1,37 @@
{{- bos_token -}}
{%- set system_prompt = "" -%}
{%- set ns = namespace(system_prompt="") -%}
{%- if messages[0]["role"] == "system" -%}
{%- set ns.system_prompt = messages[0]["content"] -%}
{%- set messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
{%- for tool in tools -%}
{%- if tool is not string -%}
{%- set tool = tool | tojson -%}
{%- endif -%}
{%- set ns.system_prompt = ns.system_prompt + tool -%}
{%- if not loop.last -%}
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
{%- endif -%}
{%- endfor -%}
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
{%- endif -%}
{%- if ns.system_prompt -%}
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
{%- endif -%}
{%- for message in messages -%}
{{- "<|im_start|>" + message["role"] + "\n" -}}
{%- set content = message["content"] -%}
{%- if content is not string -%}
{%- set content = content | tojson -%}
{%- endif -%}
{%- if message["role"] == "tool" -%}
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
{%- endif -%}
{{- content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}

View File

@@ -1,5 +1,3 @@
mistral-common>=1.8.3
-r ./requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu

View File

@@ -1,14 +1,7 @@
numpy~=1.26.4
sentencepiece~=0.2.0
# Embedding Gemma is currently a preview release:
# https://github.com/huggingface/transformers/releases/tag/v4.56.0-Embedding-Gemma-preview
# The version is needed to be able to convert Embedding Gemma models to GGUF format:
git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
# Once Embedding Gemma is officially released, we can switch to:
#transformers>=4.57.1,<5.0.0
transformers>=4.57.1,<5.0.0
gguf>=0.1.0
protobuf>=4.21.0,<5.0.0

View File

@@ -35,5 +35,6 @@ adb $adbserial shell " \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
-t 4 --batch-size 128 -ngl 99 $@ \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--batch-size 128 -ngl 99 $@ \
"

View File

@@ -45,8 +45,9 @@ adb $adbserial shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
-t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
-ngl 99 --device $device $cli_opts $@ \
"

Some files were not shown because too many files have changed in this diff Show More