mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-17 06:24:07 +00:00
Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66d65ec29b | ||
|
|
05728db18e | ||
|
|
4720819d45 | ||
|
|
d979f2b176 | ||
|
|
ecbcb7ea9d | ||
|
|
3e6ab244ad | ||
|
|
5596a35791 | ||
|
|
8d3b962f47 | ||
|
|
d903f30e25 | ||
|
|
8387ffb28d | ||
|
|
2e7e638523 | ||
|
|
a8b192b6ec | ||
|
|
c17dce4f5c | ||
|
|
88cf781f51 | ||
|
|
4e76d24f28 | ||
|
|
723c71064d | ||
|
|
37964f44f9 | ||
|
|
01cd448b8c | ||
|
|
99bd67c9b2 | ||
|
|
b68d75165a | ||
|
|
ffaafde16f | ||
|
|
efba35a860 | ||
|
|
9b62913b40 | ||
|
|
66287bdaac | ||
|
|
1ca3d1de15 | ||
|
|
bd72300591 | ||
|
|
2943210c1e | ||
|
|
3769fe6eb7 | ||
|
|
832aa94762 | ||
|
|
3af34b9ff5 | ||
|
|
f20469d919 | ||
|
|
d7d826b3c1 | ||
|
|
c747294b2d | ||
|
|
8fdf269dad | ||
|
|
a96a1120b4 | ||
|
|
244641955f | ||
|
|
47eb12b953 | ||
|
|
418dea39ce | ||
|
|
da426cb250 | ||
|
|
c830f99cfa | ||
|
|
aa6f918c1c | ||
|
|
8c2c0108dd | ||
|
|
3ea5360c00 | ||
|
|
39fb81f875 | ||
|
|
5eb0ea32f0 | ||
|
|
b68a83e641 | ||
|
|
d8aeb65cee | ||
|
|
9051663d5d | ||
|
|
72b44c0d21 | ||
|
|
bc160d3582 | ||
|
|
2b6dfe824d | ||
|
|
e8e261699a | ||
|
|
5452d736f8 | ||
|
|
ed4837891d | ||
|
|
cacc371f99 | ||
|
|
ae2368e74e | ||
|
|
9f0684f003 | ||
|
|
34ec1c3f18 | ||
|
|
e877ad8bd9 | ||
|
|
35715657cb | ||
|
|
f75c4e8bf5 | ||
|
|
99156f3a5f | ||
|
|
a0c91e8f9f | ||
|
|
07968d53e4 | ||
|
|
ba3b9c8844 | ||
|
|
94b0200a01 | ||
|
|
b908baf182 | ||
|
|
492bc31978 | ||
|
|
77d6ae4ac8 | ||
|
|
10b26ee23a | ||
|
|
3dadc88b58 | ||
|
|
39e4b1dc9b | ||
|
|
11c325c6e0 | ||
|
|
237958db33 | ||
|
|
abb9f3c42b | ||
|
|
da348c9dfb | ||
|
|
e6267a9359 | ||
|
|
2bf318fd2f | ||
|
|
c78e682245 | ||
|
|
c5897995a7 | ||
|
|
03fd9d3bb4 | ||
|
|
8004f3a8d1 | ||
|
|
eacb4b67a2 | ||
|
|
c0d0430340 | ||
|
|
3bb2fcc856 | ||
|
|
27326bfce1 | ||
|
|
ad9f692f8f | ||
|
|
8a70973557 | ||
|
|
e7f2f95c9a | ||
|
|
b55dcdef5d | ||
|
|
eeef3cfced | ||
|
|
e99f1083a0 | ||
|
|
238856ec8f | ||
|
|
ea003229d3 | ||
|
|
d0061be838 | ||
|
|
a569bda445 | ||
|
|
e2f19b320f | ||
|
|
983559d24b | ||
|
|
2b089c7758 | ||
|
|
afa6bfe4f7 | ||
|
|
ae2d3f28a8 | ||
|
|
ad8207af77 | ||
|
|
667b694278 | ||
|
|
e48349a49d | ||
|
|
ae46a61e41 | ||
|
|
65cede7c70 | ||
|
|
05fa625eac | ||
|
|
d612901116 | ||
|
|
cceb1b4e33 |
@@ -1,8 +1,8 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG ROCM_VERSION=7.0
|
||||
ARG AMDGPU_VERSION=7.0
|
||||
ARG ROCM_VERSION=7.2
|
||||
ARG AMDGPU_VERSION=7.2
|
||||
|
||||
# Target the ROCm build image
|
||||
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
|
||||
@@ -11,13 +11,12 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878
|
||||
# This is mostly tied to rocBLAS supported archs.
|
||||
# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
|
||||
|
||||
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
|
||||
#ARG ROCM_DOCKER_ARCH='gfx1151'
|
||||
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
|
||||
|
||||
# Set ROCm architectures
|
||||
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
|
||||
|
||||
@@ -11,5 +11,5 @@ runs:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/install-exe
|
||||
with:
|
||||
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe
|
||||
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-Win11-For-HIP.exe
|
||||
args: -install
|
||||
|
||||
2
.github/workflows/build-cache.yml
vendored
2
.github/workflows/build-cache.yml
vendored
@@ -68,7 +68,7 @@ jobs:
|
||||
|
||||
env:
|
||||
# Make sure this is in sync with build.yml
|
||||
HIPSDK_INSTALLER_VERSION: "25.Q3"
|
||||
HIPSDK_INSTALLER_VERSION: "26.Q1"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
8
.github/workflows/build.yml
vendored
8
.github/workflows/build.yml
vendored
@@ -1175,10 +1175,8 @@ jobs:
|
||||
runs-on: windows-2022
|
||||
|
||||
env:
|
||||
# The ROCm version must correspond to the version used in the HIP SDK.
|
||||
ROCM_VERSION: "6.4.2"
|
||||
# Make sure this is in sync with build-cache.yml
|
||||
HIPSDK_INSTALLER_VERSION: "25.Q3"
|
||||
HIPSDK_INSTALLER_VERSION: "26.Q1"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -1188,7 +1186,7 @@ jobs:
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
run: |
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/${{ env.ROCM_VERSION }}/pool/main/r/rocwmma-dev/rocwmma-dev_1.7.0.60402-120~24.04_amd64.deb"
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
|
||||
7z x rocwmma.deb
|
||||
7z x data.tar
|
||||
|
||||
@@ -1231,7 +1229,7 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DLLAMA_BUILD_BORINGSSL=ON `
|
||||
-DROCM_DIR="${env:HIP_PATH}" `
|
||||
|
||||
2
.github/workflows/gguf-publish.yml
vendored
2
.github/workflows/gguf-publish.yml
vendored
@@ -21,7 +21,7 @@ on:
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-slim
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
114
.github/workflows/release.yml
vendored
114
.github/workflows/release.yml
vendored
@@ -516,17 +516,113 @@ jobs:
|
||||
path: llama-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
|
||||
ubuntu-22-rocm:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- ROCM_VERSION: "7.2"
|
||||
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
|
||||
build: 'x64'
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt install -y build-essential git cmake wget
|
||||
|
||||
- name: Setup Legacy ROCm
|
||||
if: matrix.ROCM_VERSION == '7.2'
|
||||
id: legacy_env
|
||||
run: |
|
||||
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
|
||||
wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \
|
||||
gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null
|
||||
|
||||
sudo tee /etc/apt/sources.list.d/rocm.list << EOF
|
||||
deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${{ matrix.ROCM_VERSION }} jammy main
|
||||
EOF
|
||||
|
||||
sudo tee /etc/apt/preferences.d/rocm-pin-600 << EOF
|
||||
Package: *
|
||||
Pin: release o=repo.radeon.com
|
||||
Pin-Priority: 600
|
||||
EOF
|
||||
|
||||
sudo apt update
|
||||
sudo apt-get install -y libssl-dev rocm-hip-sdk
|
||||
|
||||
- name: Setup TheRock
|
||||
if: matrix.ROCM_VERSION != '7.2'
|
||||
id: therock_env
|
||||
run: |
|
||||
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
|
||||
mkdir install
|
||||
tar -xf *.tar.gz -C install
|
||||
export ROCM_PATH=$(pwd)/install
|
||||
echo ROCM_PATH=$ROCM_PATH >> $GITHUB_ENV
|
||||
echo PATH=$PATH:$ROCM_PATH/bin >> $GITHUB_ENV
|
||||
echo LD_LIBRARY_PATH=$ROCM_PATH/lib:$ROCM_PATH/llvm/lib:$ROCM_PATH/lib/rocprofiler-systems >> $GITHUB_ENV
|
||||
|
||||
- name: Build with native CMake HIP support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGPU_TARGETS="${{ matrix.gpu_targets }}" \
|
||||
-DGGML_HIP=ON \
|
||||
-DHIP_PLATFORM=amd \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
|
||||
windows-hip:
|
||||
runs-on: windows-2022
|
||||
|
||||
env:
|
||||
HIPSDK_INSTALLER_VERSION: "25.Q3"
|
||||
HIPSDK_INSTALLER_VERSION: "26.Q1"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- name: "radeon"
|
||||
gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
|
||||
gpu_targets: "gfx1150;gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -536,7 +632,7 @@ jobs:
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
run: |
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.0.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.0.0.70001-42~24.04_amd64.deb"
|
||||
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
|
||||
7z x rocwmma.deb
|
||||
7z x data.tar
|
||||
|
||||
@@ -559,7 +655,7 @@ jobs:
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "Downloading AMD HIP SDK Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-Win11-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP SDK"
|
||||
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
|
||||
$completed = $proc.WaitForExit(600000)
|
||||
@@ -593,20 +689,20 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.0.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DGGML_BACKEND_DL=ON `
|
||||
-DGGML_NATIVE=OFF `
|
||||
-DGGML_CPU=OFF `
|
||||
-DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" `
|
||||
-DGPU_TARGETS="${{ matrix.gpu_targets }}" `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGGML_HIP=ON `
|
||||
-DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
|
||||
md "build\bin\rocblas\library\"
|
||||
md "build\bin\hipblaslt\library"
|
||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
|
||||
cp "${env:HIP_PATH}\bin\hipblaslt.dll" "build\bin\"
|
||||
cp "${env:HIP_PATH}\bin\libhipblas.dll" "build\bin\"
|
||||
cp "${env:HIP_PATH}\bin\libhipblaslt.dll" "build\bin\"
|
||||
cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
|
||||
cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
|
||||
cp "${env:HIP_PATH}\bin\hipblaslt\library\*" "build\bin\hipblaslt\library\"
|
||||
@@ -784,6 +880,7 @@ jobs:
|
||||
- windows-cuda
|
||||
- windows-sycl
|
||||
- windows-hip
|
||||
- ubuntu-22-rocm
|
||||
- ubuntu-22-cpu
|
||||
- ubuntu-22-vulkan
|
||||
- macOS-arm64
|
||||
@@ -868,6 +965,7 @@ jobs:
|
||||
**Linux:**
|
||||
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
|
||||
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
|
||||
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz)
|
||||
- [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz)
|
||||
|
||||
**Windows:**
|
||||
|
||||
2
.github/workflows/winget.yml
vendored
2
.github/workflows/winget.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
|
||||
- name: Install komac
|
||||
run: |
|
||||
cargo binstall komac@2.11.2 -y
|
||||
cargo binstall komac@2.15.0 -y
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
||||
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
|
||||
project("llama.cpp" C CXX)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
|
||||
llama_add_compile_flags()
|
||||
|
||||
# Build info header
|
||||
#
|
||||
|
||||
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
|
||||
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
|
||||
@@ -110,29 +109,16 @@ if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
build_info
|
||||
cpp-httplib
|
||||
)
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
|
||||
# Set the correct library file extension based on platform
|
||||
if (WIN32)
|
||||
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
||||
# Add Windows-specific libraries
|
||||
set(LLGUIDANCE_PLATFORM_LIBS
|
||||
ws2_32 # Windows Sockets API
|
||||
userenv # For GetUserProfileDirectoryW
|
||||
ntdll # For NT functions
|
||||
bcrypt # For BCryptGenRandom
|
||||
)
|
||||
else()
|
||||
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
||||
set(LLGUIDANCE_PLATFORM_LIBS "")
|
||||
endif()
|
||||
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
@@ -154,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
# Add platform libraries to the main target
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
||||
endif ()
|
||||
target_link_libraries(${TARGET} PRIVATE llguidance)
|
||||
if (WIN32)
|
||||
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
||||
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
|
||||
|
||||
@@ -1578,7 +1578,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--temp"}, "N",
|
||||
{"--temp", "--temperature"}, "N",
|
||||
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.temp = std::stof(value);
|
||||
@@ -1611,7 +1611,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-nsigma"}, "N",
|
||||
{"--top-nsigma", "--top-n-sigma"}, "N",
|
||||
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_n_sigma = std::stof(value);
|
||||
@@ -1634,7 +1634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--typical"}, "N",
|
||||
{"--typical", "--typical-p"}, "N",
|
||||
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.typ_p = std::stof(value);
|
||||
@@ -2520,11 +2520,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-a", "--alias"}, "STRING",
|
||||
"set alias for model name (to be used by REST API)",
|
||||
"set model name aliases, comma-separated (to be used by API)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model_alias = value;
|
||||
for (auto & alias : string_split<std::string>(value, ',')) {
|
||||
alias = string_strip(alias);
|
||||
if (!alias.empty()) {
|
||||
params.model_alias.insert(alias);
|
||||
}
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS"));
|
||||
add_opt(common_arg(
|
||||
{"--tags"}, "STRING",
|
||||
"set model tags, comma-separated (informational, not used for routing)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
for (auto & tag : string_split<std::string>(value, ',')) {
|
||||
tag = string_strip(tag);
|
||||
if (!tag.empty()) {
|
||||
params.model_tags.insert(tag);
|
||||
}
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TAGS"));
|
||||
add_opt(common_arg(
|
||||
{"-m", "--model"}, "FNAME",
|
||||
ex == LLAMA_EXAMPLE_EXPORT_LORA
|
||||
|
||||
@@ -803,7 +803,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
||||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (builder.pos() == builder.input().size() && builder.is_partial()) {
|
||||
if (unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
|
||||
@@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
@@ -1590,9 +1573,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
|
||||
@@ -65,14 +65,25 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
} else if (!content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
bool last_was_media_marker = false;
|
||||
// join parts with newline, do not add newline before or after media markers
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type != "text") {
|
||||
bool add_new_line = true;
|
||||
if (part.type == "text") {
|
||||
add_new_line = !last_was_media_marker && !text.empty();
|
||||
last_was_media_marker = false;
|
||||
} else if (part.type == "media_marker") {
|
||||
add_new_line = false;
|
||||
last_was_media_marker = true;
|
||||
} else {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
|
||||
if (add_new_line) {
|
||||
text += '\n';
|
||||
}
|
||||
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
@@ -319,7 +330,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
throw std::invalid_argument("Missing content part type: " + part.dump());
|
||||
}
|
||||
const auto & type = part.at("type");
|
||||
if (type != "text") {
|
||||
if (type != "text" && type != "media_marker") {
|
||||
throw std::invalid_argument("Unsupported content part type: " + type.dump());
|
||||
}
|
||||
common_chat_msg_content_part msg_part;
|
||||
@@ -725,7 +736,6 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
|
||||
@@ -1511,14 +1521,17 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
|
||||
|
||||
// Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking
|
||||
bool supports_reasoning = (tmpl.source().find("<think>") != std::string::npos);
|
||||
|
||||
// Handle thinking tags appropriately based on inputs.enable_thinking
|
||||
if (string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (supports_reasoning && string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "</think>";
|
||||
} else {
|
||||
@@ -1527,19 +1540,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_
|
||||
}
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
};
|
||||
|
||||
if (supports_reasoning) {
|
||||
data.preserved_tokens.insert(data.preserved_tokens.end(), {"<think>", "</think>"});
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
|
||||
auto reasoning = p.eps();
|
||||
if (inputs.enable_thinking && extract_reasoning) {
|
||||
if (supports_reasoning && inputs.enable_thinking && extract_reasoning) {
|
||||
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
|
||||
if (data.thinking_forced_open) {
|
||||
reasoning = reasoning_content;
|
||||
@@ -1877,38 +1892,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function=",
|
||||
"</function>",
|
||||
"<parameter=",
|
||||
"</parameter>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<tool_call>\n",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">\n",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">\n",
|
||||
/* form.val_end = */ "\n</parameter>\n",
|
||||
/* form.tool_end = */ "</function>\n",
|
||||
/* form.scope_end = */ "</tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -2032,6 +2015,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["thinking"] = msg.at("reasoning_content");
|
||||
adjusted_message.erase("content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
@@ -3129,19 +3113,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
}
|
||||
|
||||
// Qwen3-Coder XML format detection (must come before Hermes 2 Pro)
|
||||
// Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates.
|
||||
// Require presence of <tool_call>, <function=...>, and <parameter=...> blocks.
|
||||
// Detect via XML markers: <tool_call>, <function=...>, and <parameter=...> blocks.
|
||||
// Also matches Step-3.5-Flash and Nemotron 3 Nano which use the same output format.
|
||||
if (src.find("<tool_call>") != std::string::npos &&
|
||||
src.find("<function>") != std::string::npos &&
|
||||
src.find("<function=") != std::string::npos &&
|
||||
src.find("<parameters>") != std::string::npos &&
|
||||
src.find("<parameter=") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
// Nemotron 3 Nano 30B A3B
|
||||
if (src.find("<think>") != std::string::npos) {
|
||||
return common_chat_params_init_nemotron_v3(tmpl, params);
|
||||
}
|
||||
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||
return common_chat_params_init_qwen3_coder(tmpl, params);
|
||||
}
|
||||
|
||||
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||
@@ -3307,7 +3285,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto content = msg.content;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
if (part.type != "text" && part.type != "media_marker") {
|
||||
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -128,7 +128,6 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
COMMON_CHAT_FORMAT_SOLAR_OPEN,
|
||||
|
||||
@@ -452,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
|
||||
bool has_suffix = string_ends_with(str, suffix);
|
||||
if (has_suffix) {
|
||||
str = str.substr(0, str.size() - suffix.size());
|
||||
}
|
||||
return has_suffix;
|
||||
}
|
||||
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$&");
|
||||
@@ -1788,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const {
|
||||
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
|
||||
return r;
|
||||
}
|
||||
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s: failed to replay last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & tokens,
|
||||
int & n_past,
|
||||
int n_batch,
|
||||
std::string_view state_path,
|
||||
bool save_state) {
|
||||
const int n_eval = tokens.size();
|
||||
if (n_eval == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (save_state && n_eval > 1) {
|
||||
const int n_tokens_before_last = n_eval - 1;
|
||||
|
||||
GGML_ASSERT(n_eval <= n_batch);
|
||||
|
||||
// Decode all but the last token so we can save the memory state before decoding the last token.
|
||||
// This is done so we can restore the session state later and replay the last token.
|
||||
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
|
||||
// memory, so we can't just remove the last token from the memory and replay the last token which
|
||||
// is the reason for this logic.
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_tokens_before_last;
|
||||
|
||||
llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
|
||||
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
|
||||
|
||||
llama_token last_token = tokens.back();
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
int32_t pos = n_past;
|
||||
batch.pos = &pos;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past++;
|
||||
} else {
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -410,7 +410,8 @@ struct common_params {
|
||||
|
||||
struct common_params_model model;
|
||||
|
||||
std::string model_alias = ""; // model alias // NOLINT
|
||||
std::set<std::string> model_alias; // model aliases // NOLINT
|
||||
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
|
||||
std::string hf_token = ""; // HF token // NOLINT
|
||||
std::string prompt = ""; // NOLINT
|
||||
std::string system_prompt = ""; // NOLINT
|
||||
@@ -670,30 +671,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
}
|
||||
|
||||
template<>
|
||||
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
|
||||
{
|
||||
std::vector<std::string> parts;
|
||||
size_t begin_pos = 0;
|
||||
size_t separator_pos = input.find(separator);
|
||||
while (separator_pos != std::string::npos) {
|
||||
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
||||
size_t delim_pos = str.find(delim);
|
||||
while (delim_pos != std::string::npos) {
|
||||
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
|
||||
parts.emplace_back(part);
|
||||
begin_pos = separator_pos + 1;
|
||||
separator_pos = input.find(separator, begin_pos);
|
||||
begin_pos = delim_pos + 1;
|
||||
delim_pos = str.find(delim, begin_pos);
|
||||
}
|
||||
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
||||
parts.emplace_back(str.substr(begin_pos));
|
||||
return parts;
|
||||
}
|
||||
|
||||
static bool string_starts_with(const std::string & str,
|
||||
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
// remove when moving to c++20
|
||||
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
|
||||
return str.size() >= prefix.size() &&
|
||||
str.compare(0, prefix.size(), prefix) == 0;
|
||||
}
|
||||
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
// remove when moving to c++20
|
||||
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
|
||||
return str.size() >= suffix.size() &&
|
||||
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
|
||||
if (string_ends_with(str, suffix)) {
|
||||
str.resize(str.size() - suffix.size());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const size_t max_len = std::min(str.size(), stop.size());
|
||||
const char last_char = str.back();
|
||||
for (size_t len = max_len; len > 0; --len) {
|
||||
if (stop[len - 1] == last_char) {
|
||||
if (string_ends_with(str, stop.substr(0, len))) {
|
||||
return str.size() - len;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
@@ -779,6 +805,23 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
// decodes a single batch of tokens for a prompt and manages session tokens
|
||||
//
|
||||
// Note: We save state before the last token so that we can replay it to ensure
|
||||
// compatibility with all memory types. Recurrent/hybrid models cannot remove
|
||||
// tokens from memory, so this approach works across all model architectures.
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & embd,
|
||||
int & n_past,
|
||||
int n_batch,
|
||||
std::string_view state_path,
|
||||
bool save_state);
|
||||
|
||||
// replays the last token after loading state to regenerate logits
|
||||
// used after loading session state to ensure the sampling context has valid logits
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
@@ -870,11 +913,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||
|
||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
||||
inline std::string llm_ffn_exps_block_regex(int idx) {
|
||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
||||
}
|
||||
|
||||
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
|
||||
}
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ value identifier::execute_impl(context & ctx) {
|
||||
auto builtins = global_builtins();
|
||||
if (!it->is_undefined()) {
|
||||
if (ctx.is_get_stats) {
|
||||
it->stats.used = true;
|
||||
value_t::stats_t::mark_used(it);
|
||||
}
|
||||
JJ_DEBUG("Identifier '%s' found, type = %s", val.c_str(), it->type().c_str());
|
||||
return it;
|
||||
@@ -277,7 +277,7 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
|
||||
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
|
||||
if (ctx.is_get_stats) {
|
||||
input->stats.used = true;
|
||||
value_t::stats_t::mark_used(input);
|
||||
input->stats.ops.insert(name);
|
||||
}
|
||||
auto builtins = input->get_builtins();
|
||||
@@ -448,7 +448,7 @@ value for_statement::execute_impl(context & ctx) {
|
||||
|
||||
// mark the variable being iterated as used for stats
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
value_t::stats_t::mark_used(iterable_val);
|
||||
iterable_val->stats.ops.insert("array_access");
|
||||
}
|
||||
|
||||
@@ -470,7 +470,7 @@ value for_statement::execute_impl(context & ctx) {
|
||||
items.push_back(std::move(tuple));
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
value_t::stats_t::mark_used(iterable_val);
|
||||
iterable_val->stats.ops.insert("object_access");
|
||||
}
|
||||
} else {
|
||||
@@ -480,7 +480,7 @@ value for_statement::execute_impl(context & ctx) {
|
||||
items.push_back(item);
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
value_t::stats_t::mark_used(iterable_val);
|
||||
iterable_val->stats.ops.insert("array_access");
|
||||
}
|
||||
}
|
||||
@@ -721,6 +721,8 @@ value member_expression::execute_impl(context & ctx) {
|
||||
int64_t arr_size = 0;
|
||||
if (is_val<value_array>(object)) {
|
||||
arr_size = object->as_array().size();
|
||||
} else if (is_val<value_string>(object)) {
|
||||
arr_size = object->as_string().length();
|
||||
}
|
||||
|
||||
if (is_stmt<slice_expression>(this->property)) {
|
||||
@@ -817,8 +819,9 @@ value member_expression::execute_impl(context & ctx) {
|
||||
}
|
||||
|
||||
if (ctx.is_get_stats && val && object && property) {
|
||||
val->stats.used = true;
|
||||
object->stats.used = true;
|
||||
value_t::stats_t::mark_used(val);
|
||||
value_t::stats_t::mark_used(object);
|
||||
value_t::stats_t::mark_used(property);
|
||||
if (is_val<value_int>(property)) {
|
||||
object->stats.ops.insert("array_access");
|
||||
} else if (is_val<value_string>(property)) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
// for converting from JSON to jinja values
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
#include <vector>
|
||||
@@ -160,6 +161,11 @@ static value tojson(const func_args & args) {
|
||||
value val_separators = args.get_kwarg_or_pos("separators", 3);
|
||||
value val_sort = args.get_kwarg_or_pos("sort_keys", 4);
|
||||
int indent = -1;
|
||||
if (args.ctx.is_get_stats) {
|
||||
// mark as used (recursively) for stats
|
||||
auto val_input = args.get_pos(0);
|
||||
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
|
||||
}
|
||||
if (is_val<value_int>(val_indent)) {
|
||||
indent = static_cast<int>(val_indent->as_int());
|
||||
}
|
||||
@@ -715,8 +721,46 @@ const func_builtins & value_string_t::get_builtins() const {
|
||||
return args.get_pos(0);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"indent", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String indent builtin not implemented");
|
||||
{"indent", [](const func_args &args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
value val_input = args.get_pos(0);
|
||||
value val_width = args.get_kwarg_or_pos("width", 1);
|
||||
const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
|
||||
const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
|
||||
if (!is_val<value_string>(val_input)) {
|
||||
throw raised_exception("indent() first argument must be a string");
|
||||
}
|
||||
std::string indent;
|
||||
if (is_val<value_int>(val_width)) {
|
||||
indent.assign(val_width->as_int(), ' ');
|
||||
} else if (is_val<value_string>(val_width)) {
|
||||
indent = val_width->as_string().str();
|
||||
} else {
|
||||
indent = " ";
|
||||
}
|
||||
std::string indented;
|
||||
std::string input = val_input->as_string().str();
|
||||
std::istringstream iss = std::istringstream(input);
|
||||
std::string line;
|
||||
while (std::getline(iss, line)) {
|
||||
if (!indented.empty()) {
|
||||
indented.push_back('\n');
|
||||
}
|
||||
if ((indented.empty() ? first : (!line.empty() || blank))) {
|
||||
indented += indent;
|
||||
}
|
||||
indented += line;
|
||||
}
|
||||
if (!input.empty() && input.back() == '\n') {
|
||||
indented.push_back('\n');
|
||||
if (blank) {
|
||||
indented += indent;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = mk_val<value_string>(indented);
|
||||
res->val_str.mark_input_based_on(val_input->as_string());
|
||||
return res;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String join builtin not implemented");
|
||||
@@ -852,6 +896,11 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
}},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_array>();
|
||||
if (args.ctx.is_get_stats) {
|
||||
// mark as used (recursively) for stats
|
||||
auto val_input = args.get_pos(0);
|
||||
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
|
||||
}
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
@@ -1007,6 +1056,11 @@ const func_builtins & value_object_t::get_builtins() const {
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
if (args.ctx.is_get_stats) {
|
||||
// mark as used (recursively) for stats
|
||||
auto val_input = args.get_pos(0);
|
||||
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
|
||||
}
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"length", [](const func_args & args) -> value {
|
||||
@@ -1319,4 +1373,21 @@ std::string value_to_string_repr(const value & val) {
|
||||
}
|
||||
}
|
||||
|
||||
// stats utility
|
||||
void value_t::stats_t::mark_used(value & val, bool deep) {
|
||||
val->stats.used = true;
|
||||
if (deep) {
|
||||
if (is_val<value_array>(val)) {
|
||||
for (auto & item : val->val_arr) {
|
||||
mark_used(item, deep);
|
||||
}
|
||||
} else if (is_val<value_object>(val)) {
|
||||
for (auto & pair : val->val_obj) {
|
||||
mark_used(pair.first, deep);
|
||||
mark_used(pair.second, deep);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -118,6 +118,8 @@ struct value_t {
|
||||
bool used = false;
|
||||
// ops can be builtin calls or operators: "array_access", "object_access"
|
||||
std::set<std::string> ops;
|
||||
// utility to recursively mark value and its children as used
|
||||
static void mark_used(value & val, bool deep = false);
|
||||
} stats;
|
||||
|
||||
value_t() = default;
|
||||
|
||||
@@ -116,7 +116,8 @@ class ModelBase:
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
|
||||
disable_mistral_community_chat_template: bool = False,
|
||||
sentence_transformers_dense_modules: bool = False):
|
||||
sentence_transformers_dense_modules: bool = False,
|
||||
fuse_gate_up_exps: bool = False):
|
||||
if type(self) is ModelBase or \
|
||||
type(self) is TextModel or \
|
||||
type(self) is MmprojModel:
|
||||
@@ -135,6 +136,9 @@ class ModelBase:
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
|
||||
self.fuse_gate_up_exps = fuse_gate_up_exps
|
||||
self._gate_exp_buffer: dict[int, Tensor] = {}
|
||||
self._up_exp_buffer: dict[int, Tensor] = {}
|
||||
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
|
||||
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
|
||||
self.metadata_override = metadata_override
|
||||
@@ -512,8 +516,31 @@ class ModelBase:
|
||||
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
# Handle gate/up expert tensor fusion if enabled
|
||||
if self.fuse_gate_up_exps and bid is not None:
|
||||
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
|
||||
self._gate_exp_buffer[bid] = data_torch
|
||||
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
|
||||
self._up_exp_buffer[bid] = data_torch
|
||||
|
||||
# Check if both gate and up are buffered for this layer
|
||||
if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
|
||||
gate_data = self._gate_exp_buffer.pop(bid)
|
||||
up_data = self._up_exp_buffer.pop(bid)
|
||||
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
|
||||
fused_data = torch.cat([gate_data, up_data], dim=1)
|
||||
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
|
||||
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
|
||||
return [(fused_name, fused_data)]
|
||||
|
||||
# If we buffered a gate/up tensor, wait for the other
|
||||
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
|
||||
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
|
||||
return []
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del name, new_name, bid, n_dims # unused
|
||||
@@ -1049,6 +1076,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
|
||||
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
|
||||
res = "glm4"
|
||||
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
|
||||
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
|
||||
res = "glm4"
|
||||
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
|
||||
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
|
||||
res = "minerva-7b"
|
||||
@@ -1082,9 +1112,6 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
|
||||
# ref: https://huggingface.co/aari1995/German_Semantic_V3
|
||||
res = "jina-v2-de"
|
||||
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
|
||||
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
|
||||
res = "glm4"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
@@ -1148,6 +1175,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de
|
||||
res = "jina-v2-de"
|
||||
if chkhsh == "a023e9fdc5a11f034d3ef515b92350e56fb2af1f66c6b6811a4444ea9bf8763d":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v5-text-nano
|
||||
res = "jina-v5-nano"
|
||||
if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d":
|
||||
# ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct
|
||||
res = "smaug-bpe"
|
||||
@@ -1163,6 +1193,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
|
||||
# ref: https://huggingface.co/core42/jais-13b
|
||||
res = "jais"
|
||||
if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a":
|
||||
# ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat
|
||||
res = "jais-2"
|
||||
if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f":
|
||||
# ref: https://huggingface.co/WisdomShell/CodeShell-7B
|
||||
res = "codeshell"
|
||||
@@ -1268,6 +1301,12 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
|
||||
res = "qwen35"
|
||||
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
|
||||
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
|
||||
res = "joyai-llm"
|
||||
if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869":
|
||||
# ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601
|
||||
res = "kanana2"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -3727,6 +3766,13 @@ class Ernie4_5Model(TextModel):
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
|
||||
if tokenizer_config_file.is_file():
|
||||
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
|
||||
tokenizer_config_json = json.load(f)
|
||||
if "add_prefix_space" in tokenizer_config_json:
|
||||
self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
@@ -3736,6 +3782,10 @@ class Ernie4_5Model(TextModel):
|
||||
if (head_dim := self.hparams.get("head_dim")) is None:
|
||||
head_dim = self.hparams["hidden_size"] // num_heads
|
||||
|
||||
if "mlp_AR" in name or "vision_model" in name:
|
||||
# skip vision model and projector tensors
|
||||
return
|
||||
|
||||
if "ernie." in name:
|
||||
name = name.replace("ernie.", "model.")
|
||||
# split the qkv weights
|
||||
@@ -3845,6 +3895,48 @@ class Ernie4_5MoeModel(Ernie4_5Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("PaddleOCRVLForConditionalGeneration")
|
||||
class PaddleOCRModel(Ernie4_5Model):
|
||||
model_arch = gguf.MODEL_ARCH.PADDLEOCR
|
||||
|
||||
|
||||
@ModelBase.register("PaddleOCRVisionModel")
|
||||
class PaddleOCRVisionModel(MmprojModel):
|
||||
# PaddleOCR-VL uses a modified version of Siglip
|
||||
min_pixels: int = 0
|
||||
max_pixels: int = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.min_pixels = self.preprocessor_config["min_pixels"]
|
||||
self.max_pixels = self.preprocessor_config["max_pixels"]
|
||||
self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels))
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
hparams = self.hparams_vision
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR)
|
||||
self.gguf_writer.add_vision_max_pixels(self.max_pixels)
|
||||
self.gguf_writer.add_vision_min_pixels(self.min_pixels)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
name = name.replace("visual.", "model.")
|
||||
|
||||
if "vision_model" in name or "mlp_AR" in name:
|
||||
if "packing_position_embedding" in name:
|
||||
return # unused
|
||||
elif "vision_model.head" in name:
|
||||
# we don't yet support image embeddings for this model
|
||||
return
|
||||
else:
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"Qwen2VLModel",
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
@@ -4581,7 +4673,7 @@ class Qwen3VLVisionModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
|
||||
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
|
||||
class Glm4VVisionModel(Qwen3VLVisionModel):
|
||||
def set_gguf_parameters(self):
|
||||
MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters
|
||||
@@ -6063,6 +6155,32 @@ class NeoBert(BertModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("EuroBertModel", "JinaEmbeddingsV5Model")
|
||||
class EuroBertModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.EUROBERT
|
||||
|
||||
def set_vocab(self):
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# EuroBert is bidirectional (encoder)
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
|
||||
self._try_set_pooling_type()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Strip "model." prefix from tensor names
|
||||
if name.startswith("model."):
|
||||
name = name[6:]
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
|
||||
class XLMRobertaModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
@@ -8630,6 +8748,17 @@ class T5EncoderModel(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Jais2ForCausalLM")
|
||||
class Jais2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAIS2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_rope_dimension_count(head_dim)
|
||||
|
||||
|
||||
@ModelBase.register("JAISLMHeadModel")
|
||||
class JaisModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAIS
|
||||
@@ -8773,7 +8902,7 @@ class Glm4Model(TextModel):
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams["num_key_value_heads"]
|
||||
n_embd = self.hparams["hidden_size"]
|
||||
head_dim = n_embd // n_head
|
||||
head_dim = self.hparams.get("head_dim", n_embd // n_head)
|
||||
# because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
|
||||
if name.endswith(("q_proj.weight", "q_proj.bias")):
|
||||
data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
|
||||
@@ -8782,6 +8911,27 @@ class Glm4Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GlmOcrForConditionalGeneration")
|
||||
class GlmOCRModel(Glm4Model):
|
||||
model_arch = gguf.MODEL_ARCH.GLM4
|
||||
use_mrope = False
|
||||
partial_rotary_factor = 0.5
|
||||
|
||||
# Note: GLM-OCR is the same as GLM4, but with an extra NextN/MTP prediction layer
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# GLM-OCR has num_hidden_layers + 1 actual layers (including NextN layer)
|
||||
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
# NextN/MTP prediction layers
|
||||
if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
|
||||
self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
|
||||
|
||||
|
||||
@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
|
||||
class Glm4MoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GLM4_MOE
|
||||
@@ -10702,7 +10852,7 @@ class LFM2Model(TextModel):
|
||||
def set_gguf_parameters(self):
|
||||
# set num_key_value_heads only for attention layers
|
||||
self.hparams["num_key_value_heads"] = [
|
||||
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
|
||||
self.hparams["num_key_value_heads"] if layer_type != "conv" else 0
|
||||
for layer_type in self.hparams["layer_types"]
|
||||
]
|
||||
|
||||
@@ -10888,6 +11038,28 @@ class LFM2AudioModel(ConformerAudioModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Lfm25AudioTokenizer")
|
||||
class LFM25AudioTokenizer(LFM2Model):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_none()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
|
||||
self.gguf_writer.add_embedding_length_out(self.hparams["output_size"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name == "istft.window" or name.startswith("emb.emb"):
|
||||
return
|
||||
|
||||
if name.startswith("lin"):
|
||||
name = name.replace("lin", "dense_2_out")
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("SmallThinkerForCausalLM")
|
||||
class SmallThinkerModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMALLTHINKER
|
||||
@@ -10979,13 +11151,17 @@ class ModernBertModel(BertModel):
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# these layers act as MLM head, so we don't need them
|
||||
if name.startswith("decoder."):
|
||||
return
|
||||
|
||||
if name.startswith("model."):
|
||||
name = name[6:]
|
||||
|
||||
if self.cls_out_labels:
|
||||
# For BertForSequenceClassification (direct projection layer)
|
||||
if name == "classifier.weight":
|
||||
name = "classifier.out_proj.weight"
|
||||
|
||||
if name == "classifier.bias":
|
||||
name = "classifier.out_proj.bias"
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@@ -11793,6 +11969,11 @@ def parse_args() -> argparse.Namespace:
|
||||
"Default these modules are not included.")
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fuse-gate-up-exps", action="store_true",
|
||||
help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.print_supported_models and args.model is None:
|
||||
parser.error("the following arguments are required: model")
|
||||
@@ -11930,7 +12111,8 @@ def main() -> None:
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
|
||||
fuse_gate_up_exps=args.fuse_gate_up_exps
|
||||
)
|
||||
|
||||
if args.vocab_only:
|
||||
|
||||
@@ -107,6 +107,7 @@ models = [
|
||||
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
|
||||
{"name": "jina-v5-nano", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v5-text-nano", },
|
||||
{"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
|
||||
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
|
||||
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
|
||||
@@ -114,6 +115,7 @@ models = [
|
||||
{"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
|
||||
{"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
|
||||
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
|
||||
{"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", },
|
||||
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
|
||||
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
|
||||
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
|
||||
@@ -149,7 +151,9 @@ models = [
|
||||
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
|
||||
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
|
||||
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
|
||||
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
|
||||
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
@@ -159,6 +163,7 @@ pre_computed_hashes = [
|
||||
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
|
||||
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
|
||||
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
|
||||
@@ -172,7 +177,6 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ cmake --build build --config release
|
||||
|
||||
1. **Retrieve and prepare model**
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
|
||||
|
||||
**Notes**:
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
|
||||
@@ -152,7 +152,9 @@ Commands and data are serialized using a custom binary protocol with:
|
||||
- **VM-specific**: Only works in virtual machines with virtio-gpu support
|
||||
- **Host dependency**: Requires properly configured host-side backend
|
||||
- **Latency**: Small overhead from VM escaping for each operation
|
||||
|
||||
- **Shared-memory size**: with the `libkrun` hypervisor, the RAM + VRAM
|
||||
addressable memory is limited to 64 GB. So the maximum GPU memory
|
||||
will be `64GB - RAM`, regardless of the hardware VRAM size.
|
||||
|
||||
* This work is pending upstream changes in the VirglRenderer
|
||||
project.
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
|
||||
**Llama.cpp + ZenDNN**
|
||||
|
||||
The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL BLIS, LibXSMM, OneDNN).
|
||||
The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL DLP, LibXSMM, OneDNN).
|
||||
|
||||
For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendnn.html
|
||||
|
||||
@@ -32,7 +32,7 @@ For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendn
|
||||
|:-------:|:-------:|:----------------------------------------------:|
|
||||
| Linux | Support | Ubuntu 20.04, 22.04, 24.04 |
|
||||
|
||||
For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/zendnnl/README.md#15-supported-os).
|
||||
For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md#15-supported-os).
|
||||
|
||||
## Hardware
|
||||
|
||||
@@ -44,9 +44,9 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
|
||||
|
||||
| CPU Family | Status | Notes |
|
||||
|:-----------------------------:|:-------:|:----------------------------------:|
|
||||
| AMD EPYC™ 9005 Series (Turin)| Support | 5th Gen - Zen 5 architecture |
|
||||
| AMD EPYC™ 9004 Series (Genoa)| Support | 4th Gen - Zen 4 architecture |
|
||||
| AMD EPYC™ 7003 Series (Milan)| Support | 3rd Gen - Zen 3 architecture |
|
||||
| AMD EPYC™ 9005 Series (Turin) | Support | 5th Gen - Zen 5 architecture |
|
||||
| AMD EPYC™ 9004 Series (Genoa) | Support | 4th Gen - Zen 4 architecture |
|
||||
| AMD EPYC™ 7003 Series (Milan) | Support | 3rd Gen - Zen 3 architecture |
|
||||
| AMD Ryzen™ AI MAX (Strix Halo)| Support | High-performance mobile processors |
|
||||
|
||||
*Notes:*
|
||||
@@ -61,7 +61,7 @@ The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** ope
|
||||
|
||||
| Operation | Status | Notes |
|
||||
|:-------------|:-------:|:----------------------------------------------:|
|
||||
| MUL_MAT | ✓ | Accelerated via ZenDNN LowOHA MatMul |
|
||||
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
|
||||
|
||||
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
|
||||
|
||||
@@ -104,7 +104,6 @@ If you want to build ZenDNN yourself or use a specific version:
|
||||
# Clone ZenDNN repository
|
||||
git clone https://github.com/amd/ZenDNN.git
|
||||
cd ZenDNN
|
||||
git checkout zendnnl
|
||||
|
||||
# Build and install (requires CMake >= 3.25)
|
||||
mkdir build && cd build
|
||||
@@ -114,7 +113,7 @@ cmake --build . --target all
|
||||
|
||||
Default installation path: `ZenDNN/build/install`
|
||||
|
||||
**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/zendnnl/README.md).
|
||||
**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md).
|
||||
|
||||
**Step 2: Build llama.cpp with custom ZenDNN path**
|
||||
|
||||
@@ -146,8 +145,7 @@ Run llama.cpp server with ZenDNN acceleration:
|
||||
|
||||
```sh
|
||||
# Set optimal configuration
|
||||
export OMP_NUM_THREADS=64 # Adjust to your CPU core count
|
||||
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance
|
||||
export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo for best performance
|
||||
|
||||
# Start server
|
||||
./build/bin/llama-server \
|
||||
@@ -160,62 +158,26 @@ export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance
|
||||
Access the server at `http://localhost:8080`.
|
||||
|
||||
**Performance tips**:
|
||||
- Set `OMP_NUM_THREADS` to match your physical core count
|
||||
- Use `ZENDNNL_MATMUL_ALGO=2` for optimal performance
|
||||
- Use `ZENDNNL_MATMUL_ALGO=1` for optimal performance
|
||||
- For NUMA systems: `numactl --cpunodebind=0 --membind=0 ./build/bin/llama-server ...`
|
||||
|
||||
## Environment Variable
|
||||
|
||||
### Build Time
|
||||
For environment variables related to ZenDNN, refer to the [ZenDNN Environment Variables Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md).
|
||||
|
||||
| Name | Value | Function |
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
| GGML_ZENDNN | ON/OFF | Enable ZenDNN backend support |
|
||||
| ZENDNN_ROOT | Path to ZenDNN installation | Set ZenDNN installation directory |
|
||||
| GGML_OPENMP | ON/OFF (recommended: ON) | Enable OpenMP for multi-threading |
|
||||
### Performance Optimization
|
||||
|
||||
### Runtime
|
||||
|
||||
| Name | Value | Function |
|
||||
|-------------------------|--------------------------|-------------------------------------------------------------------|
|
||||
| OMP_NUM_THREADS | Number (e.g., 64) | Set number of OpenMP threads (recommended: physical core count) |
|
||||
| ZENDNNL_MATMUL_ALGO | 0-5 | Select MatMul backend algorithm (see Performance Optimization) |
|
||||
| ZENDNNL_PROFILE_LOG_LEVEL | 0-4 | Profiling log level (0=disabled, 4=verbose) |
|
||||
| ZENDNNL_ENABLE_PROFILER | 0 or 1 | Enable detailed profiling (1=enabled) |
|
||||
| ZENDNNL_API_LOG_LEVEL | 0-4 | API log level (0=disabled, 4=verbose) |
|
||||
|
||||
**Example**:
|
||||
ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL DLP** algorithm:
|
||||
|
||||
```sh
|
||||
export OMP_NUM_THREADS=64
|
||||
export ZENDNNL_MATMUL_ALGO=2 # Use Blocked AOCL BLIS for best performance
|
||||
./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Test" -n 100
|
||||
export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo (recommended)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### MatMul Algorithm Selection
|
||||
|
||||
ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL BLIS** algorithm:
|
||||
|
||||
```sh
|
||||
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS (recommended)
|
||||
```
|
||||
|
||||
**Available algorithms**:
|
||||
|
||||
| Value | Algorithm | Description |
|
||||
|:-----:|:-----------------------|:----------------------------------------------|
|
||||
| 0 | Dynamic Dispatch | Automatic backend selection (default) |
|
||||
| 1 | AOCL BLIS | AOCL BLIS backend |
|
||||
| 2 | AOCL BLIS Blocked | **Blocked AOCL BLIS (recommended)** |
|
||||
| 3 | OneDNN | OneDNN backend |
|
||||
| 4 | OneDNN Blocked | Blocked OneDNN |
|
||||
| 5 | LibXSMM | LibXSMM backend |
|
||||
For more details on available algorithms, see the [ZenDNN MatMul Algorithm Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md#algorithm-details).
|
||||
|
||||
### Profiling and Debugging
|
||||
|
||||
For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/zendnnl/docs/logging.md).
|
||||
For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/logging.md).
|
||||
|
||||
## Known Issues
|
||||
|
||||
@@ -245,10 +207,9 @@ A: Currently, ZenDNN primarily supports FP32 and BF16 data types. Quantized mode
|
||||
|
||||
A: Ensure:
|
||||
1. You're using an AMD EPYC or Ryzen processor (Zen 2 or newer)
|
||||
2. `OMP_NUM_THREADS` is set appropriately (physical core count)
|
||||
3. `ZENDNNL_MATMUL_ALGO=2` is set for best performance (Blocked AOCL BLIS)
|
||||
4. You're using a sufficiently large model (small models may not benefit as much)
|
||||
5. Enable profiling to verify ZenDNN MatMul is being called
|
||||
2. `ZENDNNL_MATMUL_ALGO=1` is set for best performance (Blocked AOCL DLP)
|
||||
3. You're using a sufficiently large model (small models may not benefit as much)
|
||||
4. Enable profiling to verify ZenDNN MatMul is being called
|
||||
|
||||
### **GitHub Contribution**:
|
||||
Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-team check/address them without delay.
|
||||
|
||||
@@ -31,7 +31,7 @@ Legend:
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -96,13 +96,13 @@ Legend:
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
||||
@@ -8760,22 +8760,14 @@
|
||||
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=1","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=32","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
@@ -8786,22 +8778,14 @@
|
||||
"WebGPU: WebGPU","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
@@ -18901,3 +18885,27 @@
|
||||
"WebGPU: WebGPU","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQR","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SQRT","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SIN","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","COS","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU"
|
||||
|
||||
|
Can't render this file because it is too large.
|
@@ -77,7 +77,10 @@ causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-em
|
||||
@./scripts/causal/compare-embeddings-logits.sh
|
||||
|
||||
causal-inspect-original-model:
|
||||
@./scripts/utils/inspect-org-model.py
|
||||
@./scripts/utils/inspect-org-model.py --list-all -s
|
||||
|
||||
causal-list-original-model-tensors:
|
||||
@./scripts/utils/inspect-org-model.py --list-all-short -s
|
||||
|
||||
causal-inspect-converted-model:
|
||||
@./scripts/utils/inspect-converted-model.sh
|
||||
@@ -153,7 +156,7 @@ embedding-verify-logits-st: embedding-run-original-model-st embedding-run-conver
|
||||
|
||||
embedding-inspect-original-model:
|
||||
$(call validate_embedding_model_path,embedding-inspect-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} --list-all -s
|
||||
|
||||
embedding-inspect-converted-model:
|
||||
@CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL}
|
||||
|
||||
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
|
||||
config = config.text_config
|
||||
multimodal = True
|
||||
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
print("Hidden size: ", config.hidden_size)
|
||||
print("Number of layers: ", config.num_hidden_layers)
|
||||
print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
def print_if_exists(label, obj, attr, default="N/A"):
|
||||
val = getattr(obj, attr) if hasattr(obj, attr) else default
|
||||
print(f"{label}", val)
|
||||
|
||||
print_if_exists("Vocab size: ", config, "vocab_size")
|
||||
print_if_exists("Hidden size: ", config, "hidden_size")
|
||||
print_if_exists("Number of layers: ", config, "num_hidden_layers")
|
||||
print_if_exists("BOS token id: ", config, "bos_token_id")
|
||||
print_if_exists("EOS token id: ", config, "eos_token_id")
|
||||
|
||||
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
|
||||
if unreleased_model_name:
|
||||
|
||||
@@ -1,67 +1,290 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safetensors import safe_open
|
||||
from collections import defaultdict
|
||||
|
||||
parser = argparse.ArgumentParser(description='Process model with specified path')
|
||||
parser.add_argument('--model-path', '-m', help='Path to the model')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = os.environ.get('MODEL_PATH', args.model_path)
|
||||
if model_path is None:
|
||||
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
|
||||
MODEL_SAFETENSORS_FILE = "model.safetensors"
|
||||
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
|
||||
|
||||
# Check if there's an index file (multi-file model)
|
||||
index_path = os.path.join(model_path, "model.safetensors.index.json")
|
||||
single_file_path = os.path.join(model_path, "model.safetensors")
|
||||
DTYPE_SIZES = {
|
||||
"F64": 8, "I64": 8, "U64": 8,
|
||||
"F32": 4, "I32": 4, "U32": 4,
|
||||
"F16": 2, "BF16": 2, "I16": 2, "U16": 2,
|
||||
"I8": 1, "U8": 1, "BOOL": 1,
|
||||
"F8_E4M3": 1, "F8_E5M2": 1,
|
||||
}
|
||||
|
||||
if os.path.exists(index_path):
|
||||
# Multi-file model
|
||||
print("Multi-file model detected")
|
||||
SIZE_UNITS = ['B', 'KB', 'MB', 'GB', 'TB']
|
||||
|
||||
with open(index_path, 'r') as f:
|
||||
index_data = json.load(f)
|
||||
|
||||
# Get the weight map (tensor_name -> file_name)
|
||||
weight_map = index_data.get("weight_map", {})
|
||||
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
|
||||
index_file = model_path / MODEL_SAFETENSORS_INDEX
|
||||
|
||||
# Group tensors by file for efficient processing
|
||||
file_tensors = defaultdict(list)
|
||||
for tensor_name, file_name in weight_map.items():
|
||||
file_tensors[file_name].append(tensor_name)
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index = json.load(f)
|
||||
return index.get("weight_map", {})
|
||||
|
||||
print("Tensors in model:")
|
||||
return None
|
||||
|
||||
# Process each shard file
|
||||
for file_name, tensor_names in file_tensors.items():
|
||||
file_path = os.path.join(model_path, file_name)
|
||||
print(f"\n--- From {file_name} ---")
|
||||
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
for tensor_name in sorted(tensor_names):
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
|
||||
def get_all_tensor_names(model_path: Path) -> list[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
elif os.path.exists(single_file_path):
|
||||
# Single file model (original behavior)
|
||||
print("Single-file model detected")
|
||||
if weight_map is not None:
|
||||
return list(weight_map.keys())
|
||||
|
||||
with safe_open(single_file_path, framework="pt") as f:
|
||||
keys = f.keys()
|
||||
print("Tensors in model:")
|
||||
for key in sorted(keys):
|
||||
tensor = f.get_tensor(key)
|
||||
print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}")
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
with safe_open(single_file, framework="pt", device="cpu") as f:
|
||||
return list(f.keys())
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}")
|
||||
print("Available files:")
|
||||
if os.path.exists(model_path):
|
||||
for item in sorted(os.listdir(model_path)):
|
||||
print(f" {item}")
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return weight_map.get(tensor_name)
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
return single_file.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def read_safetensors_header(file_path: Path) -> dict:
|
||||
with open(file_path, 'rb') as f:
|
||||
header_size = struct.unpack('<Q', f.read(8))[0]
|
||||
return json.loads(f.read(header_size))
|
||||
|
||||
|
||||
def get_tensor_size_bytes(tensor_meta: dict) -> int:
|
||||
offsets = tensor_meta.get("data_offsets")
|
||||
if offsets and len(offsets) == 2:
|
||||
return offsets[1] - offsets[0]
|
||||
n_elements = 1
|
||||
for d in tensor_meta.get("shape", []):
|
||||
n_elements *= d
|
||||
return n_elements * DTYPE_SIZES.get(tensor_meta.get("dtype", "F32"), 4)
|
||||
|
||||
|
||||
def format_size(size_bytes: int) -> str:
|
||||
val = float(size_bytes)
|
||||
for unit in SIZE_UNITS[:-1]:
|
||||
if val < 1024.0:
|
||||
return f"{val:.2f} {unit}"
|
||||
val /= 1024.0
|
||||
return f"{val:.2f} {SIZE_UNITS[-1]}"
|
||||
|
||||
|
||||
def get_all_tensor_metadata(model_path: Path) -> dict[str, dict]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
file_to_tensors: dict[str, list[str]] = {}
|
||||
for tensor_name, file_name in weight_map.items():
|
||||
file_to_tensors.setdefault(file_name, []).append(tensor_name)
|
||||
|
||||
all_metadata: dict[str, dict] = {}
|
||||
for file_name, tensor_names in file_to_tensors.items():
|
||||
try:
|
||||
header = read_safetensors_header(model_path / file_name)
|
||||
for tensor_name in tensor_names:
|
||||
if tensor_name in header:
|
||||
all_metadata[tensor_name] = header[tensor_name]
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not read header from {file_name}: {e}", file=sys.stderr)
|
||||
return all_metadata
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
header = read_safetensors_header(single_file)
|
||||
return {k: v for k, v in header.items() if k != "__metadata__"}
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def normalize_tensor_name(tensor_name: str) -> str:
|
||||
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
|
||||
normalized = re.sub(r'\.\d+$', '.#', normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def list_all_tensors(
|
||||
model_path: Path,
|
||||
short: bool = False,
|
||||
show_sizes: bool = False,
|
||||
):
|
||||
tensor_names = get_all_tensor_names(model_path)
|
||||
|
||||
metadata: Optional[dict[str, dict]] = None
|
||||
if show_sizes:
|
||||
metadata = get_all_tensor_metadata(model_path)
|
||||
|
||||
total_bytes = 0
|
||||
|
||||
if short:
|
||||
seen: dict[str, str] = {}
|
||||
for tensor_name in sorted(tensor_names):
|
||||
normalized = normalize_tensor_name(tensor_name)
|
||||
if normalized not in seen:
|
||||
seen[normalized] = tensor_name
|
||||
display_pairs = list(sorted(seen.items()))
|
||||
name_width = max((len(n) for n, _ in display_pairs), default=0)
|
||||
for normalized, first_name in display_pairs:
|
||||
if metadata and first_name in metadata:
|
||||
m = metadata[first_name]
|
||||
size = get_tensor_size_bytes(m)
|
||||
total_bytes += size
|
||||
print(f"{normalized:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
|
||||
else:
|
||||
print(normalized)
|
||||
else:
|
||||
print(f" Directory {model_path} does not exist")
|
||||
exit(1)
|
||||
name_width = max((len(n) for n in tensor_names), default=0)
|
||||
for tensor_name in sorted(tensor_names):
|
||||
if metadata and tensor_name in metadata:
|
||||
m = metadata[tensor_name]
|
||||
size = get_tensor_size_bytes(m)
|
||||
total_bytes += size
|
||||
print(f"{tensor_name:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
|
||||
else:
|
||||
print(tensor_name)
|
||||
|
||||
if show_sizes:
|
||||
print(f"\nTotal: {format_size(total_bytes)}")
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
print(f"Error: Could not find tensor '{tensor_name}' in model index")
|
||||
print(f"Model path: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = model_path / tensor_file
|
||||
|
||||
try:
|
||||
header = read_safetensors_header(file_path)
|
||||
tensor_meta = header.get(tensor_name, {})
|
||||
dtype_str = tensor_meta.get("dtype")
|
||||
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
if tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if dtype_str:
|
||||
print(f"Dtype: {dtype_str}")
|
||||
if tensor_meta:
|
||||
print(f"Size: {format_size(get_tensor_size_bytes(tensor_meta))}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
if not dtype_str:
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{file_path}' was not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Print tensor information from a safetensors model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"tensor_name",
|
||||
nargs="?",
|
||||
help="Name of the tensor to inspect"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model-path",
|
||||
type=Path,
|
||||
help="Path to the model directory (default: MODEL_PATH environment variable)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list-all-short",
|
||||
action="store_true",
|
||||
help="List unique tensor patterns (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-la", "--list-all",
|
||||
action="store_true",
|
||||
help="List all tensor names with actual layer numbers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s", "--sizes",
|
||||
action="store_true",
|
||||
help="Show dtype, shape, and size for each tensor when listing"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model_path
|
||||
if model_path is None:
|
||||
model_path_str = os.environ.get("MODEL_PATH")
|
||||
if model_path_str is None:
|
||||
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model path does not exist: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not model_path.is_dir():
|
||||
print(f"Error: Model path is not a directory: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.list_all_short or args.list_all:
|
||||
list_all_tensors(model_path, short=args.list_all_short, show_sizes=args.sizes)
|
||||
else:
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list-all-short or --list-all")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
MODEL_SAFETENSORS_FILE = "model.safetensors"
|
||||
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
|
||||
|
||||
|
||||
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
|
||||
index_file = model_path / MODEL_SAFETENSORS_INDEX
|
||||
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index = json.load(f)
|
||||
return index.get("weight_map", {})
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_all_tensor_names(model_path: Path) -> list[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return list(weight_map.keys())
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
with safe_open(single_file, framework="pt", device="cpu") as f:
|
||||
return list(f.keys())
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return weight_map.get(tensor_name)
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
return single_file.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def normalize_tensor_name(tensor_name: str) -> str:
|
||||
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
|
||||
normalized = re.sub(r'\.\d+$', '.#', normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def list_all_tensors(model_path: Path, unique: bool = False):
|
||||
tensor_names = get_all_tensor_names(model_path)
|
||||
|
||||
if unique:
|
||||
seen = set()
|
||||
for tensor_name in sorted(tensor_names):
|
||||
normalized = normalize_tensor_name(tensor_name)
|
||||
if normalized not in seen:
|
||||
seen.add(normalized)
|
||||
print(normalized)
|
||||
else:
|
||||
for tensor_name in sorted(tensor_names):
|
||||
print(tensor_name)
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
print(f"Error: Could not find tensor '{tensor_name}' in model index")
|
||||
print(f"Model path: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = model_path / tensor_file
|
||||
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
if tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{file_path}' was not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Print tensor information from a safetensors model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"tensor_name",
|
||||
nargs="?", # optional (if --list is used for example)
|
||||
help="Name of the tensor to inspect"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model-path",
|
||||
type=Path,
|
||||
help="Path to the model directory (default: MODEL_PATH environment variable)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list",
|
||||
action="store_true",
|
||||
help="List unique tensor patterns in the model (layer numbers replaced with #)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model_path
|
||||
if model_path is None:
|
||||
model_path_str = os.environ.get("MODEL_PATH")
|
||||
if model_path_str is None:
|
||||
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model path does not exist: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not model_path.is_dir():
|
||||
print(f"Error: Model path is not a directory: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.list:
|
||||
list_all_tensors(model_path, unique=True)
|
||||
else:
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,12 +5,15 @@
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
params.sampling.seed = 1234;
|
||||
|
||||
const std::string_view state_file = "dump_state.bin";
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -53,35 +56,16 @@ int main(int argc, char ** argv) {
|
||||
// tokenize prompt
|
||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||
|
||||
// prepare the batch
|
||||
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
common_batch_add(batch, tokens[i], i, {0}, false);
|
||||
const bool save_state = true;
|
||||
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
|
||||
return 1;
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
||||
|
||||
// evaluate prompt
|
||||
llama_decode(ctx, batch);
|
||||
n_past += batch.n_tokens;
|
||||
|
||||
// save state (rng, logits, embedding and kv_cache) to file
|
||||
{
|
||||
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
||||
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
|
||||
|
||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||
fwrite(state_mem.data(), 1, written, fp_write);
|
||||
fclose(fp_write);
|
||||
|
||||
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
|
||||
}
|
||||
|
||||
// save state (last tokens)
|
||||
const auto n_past_saved = n_past;
|
||||
|
||||
// first run
|
||||
printf("\nfirst run: %s", params.prompt.c_str());
|
||||
|
||||
llama_batch batch = llama_batch_init(1, 0, 1);
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx, next_token);
|
||||
@@ -111,27 +95,23 @@ int main(int argc, char ** argv) {
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
std::vector<uint8_t> state_mem;
|
||||
// load state from file
|
||||
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
|
||||
size_t n_token_count_out = 0;
|
||||
|
||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||
fseek(fp_read, 0, SEEK_END);
|
||||
state_mem.resize(ftell(fp_read));
|
||||
fseek(fp_read, 0, SEEK_SET);
|
||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||
fclose(fp_read);
|
||||
|
||||
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
|
||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_past_saved;
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
@@ -160,7 +140,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// make new context
|
||||
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
|
||||
auto params_ctx3 = common_context_params_to_llama(params);
|
||||
params_ctx3.n_seq_max = 2;
|
||||
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
|
||||
|
||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||
|
||||
@@ -169,26 +151,21 @@ int main(int argc, char ** argv) {
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
std::vector<uint8_t> state_mem;
|
||||
n_token_count_out = 0;
|
||||
|
||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||
fseek(fp_read, 0, SEEK_END);
|
||||
state_mem.resize(ftell(fp_read));
|
||||
fseek(fp_read, 0, SEEK_SET);
|
||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||
fclose(fp_read);
|
||||
|
||||
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
|
||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_past_saved;
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
|
||||
@@ -730,10 +730,6 @@ extern "C" {
|
||||
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
|
||||
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
|
||||
|
||||
GGML_DEPRECATED(
|
||||
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
|
||||
"use ggml_row_size() instead");
|
||||
|
||||
GGML_API const char * ggml_type_name(enum ggml_type type);
|
||||
GGML_API const char * ggml_op_name (enum ggml_op op);
|
||||
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
||||
|
||||
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
||||
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
# Disable LTO for the feature detection code to prevent cross-module optimization
|
||||
# from inlining architecture-specific instructions into the score function.
|
||||
# Without this, LTO can cause SIGILL when loading backends on older CPUs
|
||||
# (e.g., loading power10 backend on power9 crashes before feature check runs).
|
||||
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
|
||||
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
|
||||
endfunction()
|
||||
|
||||
|
||||
@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
|
||||
namespace ggml::cpu::amx {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
// handle only 2d gemm for now
|
||||
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
|
||||
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
|
||||
};
|
||||
|
||||
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
|
||||
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
|
||||
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
|
||||
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
|
||||
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
|
||||
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
|
||||
// src1 must be host buffer
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
// src1 must be float32
|
||||
if (op->src[1]->type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (op->op != GGML_OP_MUL_MAT) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
auto * src0 = op->src[0];
|
||||
auto * src1 = op->src[1];
|
||||
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
return false;
|
||||
}
|
||||
if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
if (op->ne[0] % (TILE_N * 2)) {
|
||||
return false;
|
||||
}
|
||||
int alignment;
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
alignment = TILE_K;
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
alignment = 256; // QK_K
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
alignment = 16;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
if (src0->ne[0] % alignment) {
|
||||
return false;
|
||||
}
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic ignored "-Wpedantic"
|
||||
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
||||
@@ -202,35 +201,27 @@ struct tile_config_t{
|
||||
// advanced-matrix-extensions-intrinsics-functions.html
|
||||
//
|
||||
|
||||
#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
|
||||
void ggml_tile_config_init(void) {
|
||||
static thread_local bool is_first_time = true;
|
||||
inline void ggml_tile_config_init(void) {
|
||||
static thread_local bool done = false;
|
||||
|
||||
if (!is_first_time) {
|
||||
if (done) {
|
||||
return;
|
||||
}
|
||||
|
||||
static thread_local tile_config_t tc;
|
||||
tile_config_t current_tc;
|
||||
_tile_storeconfig(¤t_tc);
|
||||
alignas(64) tile_config_t tc = {};
|
||||
tc.palette_id = 1;
|
||||
tc.start_row = 0;
|
||||
tc.rows[0] = 8; tc.colsb[0] = 64;
|
||||
tc.rows[1] = 8; tc.colsb[1] = 64;
|
||||
tc.rows[2] = 16; tc.colsb[2] = 32;
|
||||
tc.rows[3] = 16; tc.colsb[3] = 32;
|
||||
tc.rows[4] = 16; tc.colsb[4] = 64;
|
||||
tc.rows[5] = 16; tc.colsb[5] = 64;
|
||||
tc.rows[6] = 16; tc.colsb[6] = 64;
|
||||
tc.rows[7] = 16; tc.colsb[7] = 64;
|
||||
|
||||
// load only when config changes
|
||||
if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
|
||||
memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
|
||||
tc.palette_id = 1;
|
||||
tc.start_row = 0;
|
||||
TC_CONFIG_TILE(TMM0, 8, 64);
|
||||
TC_CONFIG_TILE(TMM1, 8, 64);
|
||||
TC_CONFIG_TILE(TMM2, 16, 32);
|
||||
TC_CONFIG_TILE(TMM3, 16, 32);
|
||||
TC_CONFIG_TILE(TMM4, 16, 64);
|
||||
TC_CONFIG_TILE(TMM5, 16, 64);
|
||||
TC_CONFIG_TILE(TMM6, 16, 64);
|
||||
TC_CONFIG_TILE(TMM7, 16, 64);
|
||||
_tile_loadconfig(&tc);
|
||||
}
|
||||
|
||||
is_first_time = false;
|
||||
_tile_loadconfig(&tc);
|
||||
done = true;
|
||||
}
|
||||
|
||||
// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
|
||||
@@ -268,33 +259,6 @@ int get_row_size(int K) {
|
||||
return row_size;
|
||||
}
|
||||
|
||||
// vectorized dtype conversion
|
||||
inline float FP16_TO_FP32(ggml_half val) {
|
||||
__m256i v = _mm256_setr_epi16(
|
||||
val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||
__m512 o = _mm512_cvtph_ps(v);
|
||||
return _mm512_cvtss_f32(o);
|
||||
}
|
||||
|
||||
inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
|
||||
__m256i v = _mm256_set1_epi16(val);
|
||||
return _mm512_cvtph_ps(v);
|
||||
}
|
||||
|
||||
// horizontal reduce
|
||||
inline float _mm512_reduce_max_ps(const __m512 x) {
|
||||
__m512 v = x;
|
||||
__m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
v1 = _mm512_shuffle_ps(v, v, 0x4E);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
v1 = _mm512_shuffle_ps(v, v, 0xB1);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
return _mm512_cvtss_f32(v);
|
||||
}
|
||||
|
||||
// transpose utils
|
||||
#define SHUFFLE_EPI32(a, b, mask) \
|
||||
_mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
|
||||
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
|
||||
K, (const float *)src1->data + mb_start * K, \
|
||||
(const type *)src0->data + nb_start * K, \
|
||||
(float *)dst->data + mb_start * ldc + nb_start, ldc);
|
||||
K, (const float *)src1->data + src1_offset + mb_start * K, \
|
||||
(const type *)src0->data + src0_offset + nb_start * K, \
|
||||
(float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
|
||||
|
||||
|
||||
// re-organize in the format {NB, KB, TILE_SIZE}:
|
||||
@@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
|
||||
}
|
||||
};
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
|
||||
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
|
||||
KB, (const char *)wdata + 0 * row_size_A, \
|
||||
(const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
|
||||
(float *) dst->data + 0 * N + nb_start, ldc)
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
|
||||
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
|
||||
KB, wdata_batch, \
|
||||
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
|
||||
(float *) dst->data + dst_offset + nb_start, ldc)
|
||||
|
||||
template <typename TA, typename TB, typename TC, int BLOCK_K,
|
||||
typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
|
||||
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
|
||||
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
|
||||
|
||||
if (need_unpack) {
|
||||
unpack_B<TB>(Tile1, B_blk0);
|
||||
unpack_B<TB>(Tile1, B_blk1);
|
||||
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
|
||||
} else {
|
||||
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
|
||||
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
|
||||
});
|
||||
}
|
||||
|
||||
// ne2 is passed explicitly to help compiler optimize repeated calls
|
||||
inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
|
||||
const int64_t i2 = batch_idx % ne2;
|
||||
const int64_t i3 = batch_idx / ne2;
|
||||
return i3 * t->nb[3] + i2 * t->nb[2];
|
||||
}
|
||||
|
||||
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
|
||||
struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
|
||||
|
||||
const int M = dst->ne[1];
|
||||
const int K = src0->ne[0];
|
||||
const int64_t n_batch = dst->ne[2] * dst->ne[3];
|
||||
|
||||
size_t desired_wsize = 0;
|
||||
|
||||
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
||||
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
|
||||
desired_wsize = M * row_size_A;
|
||||
desired_wsize = n_batch * M * row_size_A;
|
||||
});
|
||||
|
||||
return desired_wsize;
|
||||
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
|
||||
// src1: input in shape of {M, K}, float32
|
||||
// dst: output in shape of {M, N}, float32
|
||||
//
|
||||
// the function performs: dst = src1 @ src0.T
|
||||
// the function performs: dst = src1 @ src0.T for each batch
|
||||
//
|
||||
void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
struct ggml_tensor * src0 = dst->src[0];
|
||||
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
const int K = src0->ne[0];
|
||||
const int ldc = dst->nb[1] / dst->nb[0];
|
||||
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t n_batch = ne2 * dst->ne[3];
|
||||
|
||||
if (is_floating_type) {
|
||||
constexpr int BLOCK_M = 4;
|
||||
constexpr int BLOCK_N = 6;
|
||||
const int MB = div_up(M, BLOCK_M);
|
||||
const int NB = div_up(N, BLOCK_N);
|
||||
|
||||
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
|
||||
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
|
||||
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
|
||||
for (int i = begin; i < end; ++i) {
|
||||
int mb = i / NB;
|
||||
int nb = i % NB;
|
||||
int batch_idx = i / (MB * NB);
|
||||
int remaining = i % (MB * NB);
|
||||
int mb = remaining / NB;
|
||||
int nb = remaining % NB;
|
||||
|
||||
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
|
||||
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
void * wdata = params->wdata;
|
||||
|
||||
//TODO: performance improvement: merge quant A
|
||||
if (params->ith == 0) {
|
||||
// if (params->ith == 0) {
|
||||
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
||||
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
|
||||
const size_t desired_wsize = M * row_size_A;
|
||||
const size_t desired_wsize = n_batch * M * row_size_A;
|
||||
if (params->wsize < desired_wsize) {
|
||||
GGML_ABORT("insufficient work space size");
|
||||
}
|
||||
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
|
||||
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
|
||||
|
||||
const float * A_data = static_cast<const float *>(src1->data);
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
|
||||
}
|
||||
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
|
||||
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
|
||||
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
// }
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
constexpr int BLOCK_N = TILE_N * kTilesN;
|
||||
const int NB = div_up(N, BLOCK_N);
|
||||
|
||||
parallel_for_ggml(params, NB, [&](int begin, int end) {
|
||||
parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
|
||||
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
||||
const int KB = K / blck_size;
|
||||
const int TILE_SIZE = get_tile_size<type>();
|
||||
const int row_size_A = KB * sizeof(vec_dot_type);
|
||||
for (int i = begin; i < end; ++i) {
|
||||
int nb = i;
|
||||
int batch_idx = i / NB;
|
||||
int nb = i % NB;
|
||||
|
||||
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
|
||||
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
|
||||
const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
|
||||
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
|
||||
|
||||
@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
const int MB = div_up(M, BLOCK_M);
|
||||
const int NB = div_up(N, BLOCK_N);
|
||||
|
||||
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
|
||||
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
|
||||
// init tile config for each thread
|
||||
ggml_tile_config_init();
|
||||
|
||||
@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
const int row_size_A = KB * sizeof(vec_dot_type);
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
int mb = i / NB;
|
||||
int nb = i % NB;
|
||||
int batch_idx = i / (MB * NB);
|
||||
int remaining = i % (MB * NB);
|
||||
int mb = remaining / NB;
|
||||
int nb = remaining % NB;
|
||||
|
||||
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
|
||||
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
|
||||
const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
|
||||
tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
|
||||
mb_size, nb_size, KB,
|
||||
(const char *)wdata + mb_start * row_size_A,
|
||||
(const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
|
||||
(float *) dst->data + mb_start * N + nb_start, ldc);
|
||||
wdata_batch + mb_start * row_size_A,
|
||||
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
|
||||
(float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -42,11 +42,14 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -55,11 +58,14 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
||||
@@ -67,8 +73,10 @@
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// repack.cpp
|
||||
@@ -77,19 +85,23 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__POWERPC__) || defined(__powerpc__)
|
||||
@@ -110,11 +122,14 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -123,11 +138,14 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__loongarch64)
|
||||
@@ -148,11 +166,14 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -161,25 +182,22 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__riscv)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
@@ -193,11 +211,14 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -205,11 +226,14 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__s390x__)
|
||||
@@ -236,11 +260,14 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -249,11 +276,14 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__wasm__)
|
||||
@@ -282,11 +312,14 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -295,11 +328,14 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#endif
|
||||
|
||||
@@ -498,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
float * res_ptr = s;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
|
||||
float32x4_t sumf = vdupq_n_f32(0);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
|
||||
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
|
||||
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
|
||||
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
|
||||
|
||||
int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
|
||||
int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
|
||||
int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
|
||||
int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
|
||||
int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
|
||||
int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
|
||||
int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
|
||||
int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
|
||||
|
||||
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
|
||||
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
|
||||
|
||||
int32x4_t sumi = vdupq_n_s32(0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
|
||||
sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
|
||||
sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
|
||||
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
|
||||
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
|
||||
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
|
||||
|
||||
float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
|
||||
float32x4_t b_d = {
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
|
||||
};
|
||||
float32x4_t d = a_d * b_d;
|
||||
|
||||
sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
|
||||
}
|
||||
|
||||
vst1q_f32(res_ptr + x * 4, sumf);
|
||||
}
|
||||
return;
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -785,6 +860,165 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[col_groups];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
|
||||
float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
|
||||
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
|
||||
float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
||||
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
||||
int32x4_t acc_lo[col_groups];
|
||||
int32x4_t acc_hi[col_groups];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
|
||||
uint8x16_t qh[col_groups][8];
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
||||
}
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_mins[2];
|
||||
int16x8_t q5sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
int8x16_t q8_qs[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
|
||||
}
|
||||
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
uint8x16_t q5_cols[8];
|
||||
uint8x16_t hbit_lo[8];
|
||||
uint8x16_t hbit_hi[8];
|
||||
int8x16_t q5_lo[8];
|
||||
int8x16_t q5_hi[8];
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
|
||||
hbit_lo[i] = vandq_u8(qh[c][i], mone);
|
||||
hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
|
||||
qh[c][i] = vshrq_n_u8(qh[c][i], 2);
|
||||
q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
|
||||
q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
|
||||
}
|
||||
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
|
||||
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
|
||||
}
|
||||
|
||||
// Scales
|
||||
// row c0123 blk0 and blk1
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
||||
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
|
||||
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
|
||||
// row c4567 blk0 and blk1
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
||||
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
|
||||
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
|
||||
|
||||
// Bias Correction
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
} // for sb
|
||||
|
||||
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
|
||||
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -3005,6 +3239,87 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
|
||||
float32x4_t sumf[4];
|
||||
for (int m = 0; m < 4; m++) {
|
||||
sumf[m] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
|
||||
float32x4_t b_d = {
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
|
||||
};
|
||||
|
||||
int32x4_t sumi_0 = vdupq_n_s32(0);
|
||||
int32x4_t sumi_1 = vdupq_n_s32(0);
|
||||
int32x4_t sumi_2 = vdupq_n_s32(0);
|
||||
int32x4_t sumi_3 = vdupq_n_s32(0);
|
||||
|
||||
for (int k = 0; k < 4; k++) {
|
||||
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
|
||||
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
|
||||
|
||||
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
|
||||
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
|
||||
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
|
||||
|
||||
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
|
||||
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
|
||||
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
|
||||
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
|
||||
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
|
||||
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
|
||||
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
|
||||
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
|
||||
}
|
||||
|
||||
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
||||
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
||||
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
||||
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -3205,6 +3520,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 8 accumulators: 2 row pairs, 4 col pairs
|
||||
float32x4_t acc_f32[acc_size];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// d5 0 1 2 3, 4 5 6 7
|
||||
float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
|
||||
float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
|
||||
// d8 0 1 2 3
|
||||
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
||||
// mins
|
||||
float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
|
||||
float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
|
||||
|
||||
// Precomputation of scales and mins
|
||||
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
||||
float32x4_t sbd_min_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_min_4567[q8_k_blocklen];
|
||||
|
||||
sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
|
||||
sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
|
||||
sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
|
||||
sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
|
||||
|
||||
sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
|
||||
sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
|
||||
sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
|
||||
sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
|
||||
|
||||
sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
|
||||
sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
|
||||
sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
|
||||
sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
|
||||
|
||||
sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
|
||||
sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
|
||||
sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
|
||||
sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
|
||||
|
||||
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
||||
const int16x8_t bsums[q8_k_blocklen] = {
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[QK_K / 64][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
||||
int32x4_t bias_acc[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t qh[col_groups][8];
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
||||
}
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Int accumulators for qs vecdot (4 row * 2 col quartets)
|
||||
int32x4_t acc_lo[acc_size];
|
||||
int32x4_t acc_hi[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_scales[2];
|
||||
int16x8_t q5sb_mins[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
||||
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
||||
|
||||
// 0..3 & 32..35
|
||||
const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
|
||||
const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
||||
|
||||
// NOTE: This is the only difference with q4_K
|
||||
const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
|
||||
const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
|
||||
qh[0][k] = vshrq_n_u8(qh[0][k], 2);
|
||||
const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
|
||||
const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
|
||||
qh[1][k] = vshrq_n_u8(qh[1][k], 2);
|
||||
// From here, same as q4_K
|
||||
|
||||
const int8x16_t q5_0123_lo =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
|
||||
const int8x16_t q5_0123_hi =
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
|
||||
|
||||
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
||||
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
||||
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
||||
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
||||
|
||||
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
||||
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
||||
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
||||
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
||||
|
||||
const int8x16_t q5_4567_lo =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
|
||||
const int8x16_t q5_4567_hi =
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
|
||||
|
||||
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
||||
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
||||
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
||||
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
||||
|
||||
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
||||
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
||||
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
||||
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
||||
}
|
||||
|
||||
// Scale and bias application
|
||||
// acc is stored interleaved to match output layout
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
// Bias correction
|
||||
// row c0123 blk0 and blk1
|
||||
const float32x4_t sumf_0123 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
||||
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
const float32x4_t sumf_4567 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
||||
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
||||
|
||||
// Bias
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
||||
|
||||
// row c0123 blk0 and blk1
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
||||
acc_f32[2 * row + 1] =
|
||||
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
||||
}
|
||||
} // for b
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
||||
@@ -1954,3 +1954,773 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
#endif
|
||||
}
|
||||
|
||||
static const uint8_t sign_gather_indices_arr[64] = {
|
||||
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
|
||||
4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
|
||||
};
|
||||
|
||||
static const uint8_t sign_bit_masks_arr[64] = {
|
||||
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
|
||||
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128
|
||||
};
|
||||
|
||||
static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
|
||||
|
||||
const block_iq2_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
|
||||
|
||||
// --- Pre-load Constants ---
|
||||
uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1};
|
||||
vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8);
|
||||
uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5};
|
||||
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8);
|
||||
|
||||
// Constants for sign extraction
|
||||
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
|
||||
vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
|
||||
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const uint8_t * signs_ptr = qs + 32;
|
||||
|
||||
float sum_block = 0.0f;
|
||||
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
// Combine low + high bits
|
||||
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8);
|
||||
qs += 8;
|
||||
uint16_t qh_val;
|
||||
memcpy(&qh_val, qh, 2);
|
||||
qh += 2;
|
||||
vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2);
|
||||
vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2);
|
||||
vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16);
|
||||
vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8);
|
||||
v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8);
|
||||
|
||||
// Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000
|
||||
v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8);
|
||||
vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8);
|
||||
|
||||
// Multiply by 8 to get byte offset, instead of element offset
|
||||
v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8);
|
||||
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8);
|
||||
|
||||
// Lookup Grid using Byte Offsets
|
||||
vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8);
|
||||
|
||||
vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals);
|
||||
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8);
|
||||
|
||||
// Load signs and generate sign mask
|
||||
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8);
|
||||
signs_ptr += 8;
|
||||
|
||||
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
||||
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
|
||||
|
||||
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
|
||||
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
|
||||
|
||||
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
|
||||
q8 += 64;
|
||||
|
||||
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
|
||||
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64);
|
||||
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16));
|
||||
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16));
|
||||
int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16));
|
||||
int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16));
|
||||
|
||||
uint8_t sc0 = scales[0];
|
||||
uint8_t sc1 = scales[1];
|
||||
scales += 2;
|
||||
|
||||
sum_block += s0 * (2 * (sc0 & 0xF) + 1);
|
||||
sum_block += s1 * (2 * (sc0 >> 4) + 1);
|
||||
sum_block += s2 * (2 * (sc1 & 0xF) + 1);
|
||||
sum_block += s3 * (2 * (sc1 >> 4) + 1);
|
||||
}
|
||||
sumf += sum_block * combined_scale;
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
|
||||
|
||||
const block_iq2_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
|
||||
|
||||
// Pre-load Constants
|
||||
vuint8m2_t v_ids = __riscv_vid_v_u8m2(32);
|
||||
vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32);
|
||||
vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32);
|
||||
vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32);
|
||||
vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32);
|
||||
uint16_t shift_qh_arr[4] = {11, 9, 7, 5};
|
||||
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
|
||||
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const uint8_t * signs_ptr = qs + 32;
|
||||
float sum_block = 0.0f;
|
||||
|
||||
for (int ib = 0; ib < 8; ++ib) {
|
||||
|
||||
// Load Low Bits [4 bytes]
|
||||
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4);
|
||||
qs += 4;
|
||||
|
||||
// Load 1 byte. It contains bits for 4 mini-blocks.
|
||||
uint8_t qh_val = *qh++;
|
||||
|
||||
// Combine Low + High bits of 10bit indices
|
||||
vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4);
|
||||
vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4);
|
||||
vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4);
|
||||
v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4);
|
||||
vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4);
|
||||
vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4);
|
||||
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4);
|
||||
|
||||
// Lookup Grid
|
||||
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4)));
|
||||
|
||||
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4);
|
||||
signs_ptr += 4;
|
||||
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
||||
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32);
|
||||
|
||||
// generating sign mask
|
||||
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32);
|
||||
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32);
|
||||
|
||||
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32);
|
||||
q8 += 32;
|
||||
|
||||
// apply signs
|
||||
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32);
|
||||
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32);
|
||||
|
||||
// Reduction
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
// Reduce 0-15 (First Half)
|
||||
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16));
|
||||
|
||||
// Reduce 16-31 (Second Half)
|
||||
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16));
|
||||
|
||||
// Apply sub Scales
|
||||
uint8_t sc = *scales++;
|
||||
|
||||
sum_block += s0 * (2 * (sc & 0xF) + 1);
|
||||
sum_block += s1 * (2 * (sc >> 4) + 1);
|
||||
}
|
||||
sumf += sum_block * combined_scale;
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 128:
|
||||
ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
case 256:
|
||||
ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_iq3_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
const uint64_t * grid64 = (const uint64_t *)iq3s_grid;
|
||||
|
||||
// --- Pre-load Constants ---
|
||||
const uint16_t qh_bit_shifts_arr[16] = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
|
||||
};
|
||||
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
|
||||
vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
|
||||
vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
const float combined_scale = d * y[i].d;
|
||||
|
||||
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
||||
const uint8_t * GGML_RESTRICT signs = x[i].signs;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
float sum_block = 0.0f;
|
||||
|
||||
// Loop: Process 64 weights (16 mini-blocks of 4) per iteration
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
|
||||
vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
|
||||
qs += 16;
|
||||
|
||||
uint16_t qh_val;
|
||||
memcpy(&qh_val, qh, 2);
|
||||
qh += 2;
|
||||
|
||||
vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
|
||||
// Extract bits: (qh >> i) & 1
|
||||
v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
|
||||
v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);
|
||||
|
||||
vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
|
||||
v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
|
||||
v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
|
||||
vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);
|
||||
|
||||
// Grid value is 4xuint8
|
||||
vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);
|
||||
vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);
|
||||
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
|
||||
signs += 8;
|
||||
|
||||
// Generate sign mask
|
||||
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
||||
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
|
||||
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
|
||||
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
|
||||
|
||||
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
|
||||
q8 += 64;
|
||||
|
||||
// Apply Signs
|
||||
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
|
||||
vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);
|
||||
|
||||
// Reduction
|
||||
vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
|
||||
vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
|
||||
int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));
|
||||
|
||||
// Apply sub-scales
|
||||
uint8_t sc_byte = *scales++;
|
||||
int sc_lo = (sc_byte & 0xF) * 2 + 1;
|
||||
int sc_hi = (sc_byte >> 4) * 2 + 1;
|
||||
|
||||
sum_block += s_lo * sc_lo + s_hi * sc_hi;
|
||||
}
|
||||
sumf += sum_block * combined_scale;
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_tq1_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
float sumf = 0.0f;
|
||||
uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// First loop.
|
||||
vint32m4_t suml1;
|
||||
{
|
||||
const int vl = 32;
|
||||
vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl);
|
||||
|
||||
vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl);
|
||||
vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl);
|
||||
vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl);
|
||||
vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl);
|
||||
vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl);
|
||||
|
||||
vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl);
|
||||
vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl);
|
||||
vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl);
|
||||
vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl);
|
||||
vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl);
|
||||
|
||||
vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl);
|
||||
vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl);
|
||||
vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl);
|
||||
vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl);
|
||||
vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl);
|
||||
|
||||
vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl);
|
||||
vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl);
|
||||
suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl);
|
||||
}
|
||||
|
||||
// Second loop.
|
||||
vint32m2_t suml2;
|
||||
{
|
||||
const int vl = 16;
|
||||
vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl);
|
||||
|
||||
vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl);
|
||||
vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);
|
||||
vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);
|
||||
vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);
|
||||
vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);
|
||||
|
||||
vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl);
|
||||
vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl);
|
||||
vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl);
|
||||
vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl);
|
||||
vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl);
|
||||
|
||||
vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
|
||||
vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);
|
||||
vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);
|
||||
vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);
|
||||
vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);
|
||||
|
||||
vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl);
|
||||
vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl);
|
||||
suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl);
|
||||
}
|
||||
|
||||
// Third loop.
|
||||
vint32m2_t suml3;
|
||||
{
|
||||
const int vl = 16;
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, &x[i].qh[0], 4);
|
||||
// Prevent fusion with vmv.
|
||||
__asm__ __volatile__("" : "+r"(qh));
|
||||
vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4));
|
||||
|
||||
vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl);
|
||||
|
||||
vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl);
|
||||
|
||||
vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl);
|
||||
|
||||
vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
|
||||
suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl);
|
||||
}
|
||||
|
||||
vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16);
|
||||
sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16);
|
||||
sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16);
|
||||
|
||||
vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);
|
||||
sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_tq2_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
int32_t sumi = 0;
|
||||
|
||||
for (size_t j = 0; j < sizeof(x[0].qs); j += 32) {
|
||||
const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32];
|
||||
const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32];
|
||||
const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32];
|
||||
const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32];
|
||||
const uint8_t* px = &x[i].qs[j];
|
||||
|
||||
size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32);
|
||||
vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2);
|
||||
|
||||
size_t vl = __riscv_vsetvl_e8m1(32);
|
||||
|
||||
vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl);
|
||||
|
||||
vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl);
|
||||
vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl);
|
||||
vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl);
|
||||
vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl);
|
||||
|
||||
// l=0 (bits 1:0)
|
||||
vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl);
|
||||
vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl);
|
||||
|
||||
// l=1 (bits 3:2)
|
||||
vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl);
|
||||
vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl);
|
||||
|
||||
// l=2 (bits 5:4)
|
||||
vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl);
|
||||
vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl);
|
||||
|
||||
// l=3 (bits 7:6)
|
||||
vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros
|
||||
vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl);
|
||||
|
||||
// 4. Multiply and accumulate
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl);
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl);
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl);
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl);
|
||||
|
||||
vlmax_16m2 = __riscv_vsetvl_e16m2(32);
|
||||
vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2);
|
||||
|
||||
sumi += __riscv_vmv_x_s_i32m1_i32(vred32);
|
||||
}
|
||||
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
sumf += (float)sumi * d;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_iq1_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
// Load qh once for the entire superblock.
|
||||
vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8);
|
||||
|
||||
// Calculate ls.
|
||||
vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8);
|
||||
temp = __riscv_vand_vx_u16mf2(temp, 7, 8);
|
||||
vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8));
|
||||
ls = __riscv_vadd_vx_i32m1(ls, 1, 8);
|
||||
|
||||
// Calculate delta.
|
||||
vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8);
|
||||
vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8);
|
||||
vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8);
|
||||
vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8);
|
||||
|
||||
// Load qs.
|
||||
vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32);
|
||||
|
||||
// Prepare the indices.
|
||||
const uint64_t shift = 0x0009000600030000;
|
||||
vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8));
|
||||
vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2(
|
||||
__riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32));
|
||||
vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh));
|
||||
vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32);
|
||||
qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32);
|
||||
qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32);
|
||||
qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32);
|
||||
qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32);
|
||||
vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32);
|
||||
|
||||
// Final lsums.
|
||||
int32_t lsums_s[8];
|
||||
vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
// Sub-blocks 1-4
|
||||
{
|
||||
vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0);
|
||||
vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16));
|
||||
vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128);
|
||||
vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128);
|
||||
lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32));
|
||||
lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32));
|
||||
lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32));
|
||||
lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32));
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
// Sub-blocks 5-8
|
||||
{
|
||||
vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1);
|
||||
vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16));
|
||||
vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128);
|
||||
vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128);
|
||||
lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32));
|
||||
lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32));
|
||||
lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32));
|
||||
lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32));
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8);
|
||||
|
||||
// Calculate the bsums.
|
||||
vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16);
|
||||
const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0));
|
||||
const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8));
|
||||
const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8));
|
||||
const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8);
|
||||
|
||||
// Accumulation.
|
||||
vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8);
|
||||
vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8);
|
||||
|
||||
// Update sumf.
|
||||
int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
|
||||
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
|
||||
sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_iq1_m * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
iq1m_scale_t scale;
|
||||
float sumf = 0.0f;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const int8_t * q8 = y[i].qs;
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
|
||||
// Accumulators.
|
||||
vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16);
|
||||
vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16);
|
||||
|
||||
// We process 4 sub-blocks together.
|
||||
for (int ib = 0; ib < QK_K/128; ib++) {
|
||||
// Load qh for 4 sub-blocks.
|
||||
const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8);
|
||||
const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8);
|
||||
const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8);
|
||||
const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1(
|
||||
__riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16);
|
||||
qh += 8;
|
||||
|
||||
// Prepare grid indices.
|
||||
const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16);
|
||||
const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8));
|
||||
vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16);
|
||||
index = __riscv_vsll_vx_u16m1(index, 3, 16);
|
||||
qs += 16;
|
||||
|
||||
// Load the grid.
|
||||
const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4(
|
||||
__riscv_vluxei16_v_u64m4(iq1s_grid, index, 16)));
|
||||
|
||||
// Prepare the deltas.
|
||||
const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16(
|
||||
__riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16);
|
||||
const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16);
|
||||
const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16);
|
||||
const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4(
|
||||
__riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16));
|
||||
|
||||
// Load q8 for sub-blocks.
|
||||
const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
|
||||
q8 += 128;
|
||||
|
||||
// Calculate the lsums.
|
||||
const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128);
|
||||
const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128);
|
||||
|
||||
// Prepare the scales.
|
||||
const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1;
|
||||
const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1;
|
||||
const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1;
|
||||
const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1;
|
||||
const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1;
|
||||
const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1;
|
||||
const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1;
|
||||
const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1;
|
||||
sc += 2;
|
||||
|
||||
// Accumulate in acc0 and acc1 for each sub-block.
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16);
|
||||
//
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16);
|
||||
//
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16);
|
||||
//
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16);
|
||||
}
|
||||
|
||||
// Reduce and accumulate in `sumf`.
|
||||
vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16));
|
||||
int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16));
|
||||
sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -522,7 +522,8 @@ template<typename block_tx8>
|
||||
static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
|
||||
static_assert(
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>,
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8> ||
|
||||
std::is_same_v<block_tx8, block_mxfp4x8>,
|
||||
"Unsupported block type");
|
||||
|
||||
const int qk = QK8_0;
|
||||
@@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
// Load 8 E8M0 exponents and convert to float via LUT
|
||||
// Rearranged to match changemask order: 0,4,1,5,2,6,3,7
|
||||
col_scale_f32 = _mm256_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
|
||||
}
|
||||
|
||||
// Load and convert to FP32 scale from block_q8_0
|
||||
@@ -628,7 +641,8 @@ template<typename block_tx8>
|
||||
static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
|
||||
static_assert(
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>,
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8> ||
|
||||
std::is_same_v<block_tx8, block_mxfp4x8>,
|
||||
"Unsupported block type");
|
||||
|
||||
const int qk = QK8_0;
|
||||
@@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
//TODO: simd-ify
|
||||
col_scale_f32 = _mm512_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
|
||||
}
|
||||
|
||||
// Process LHS in pairs of rows
|
||||
@@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
//TODO: simd-ify
|
||||
col_scale_f32 = _mm512_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
|
||||
}
|
||||
|
||||
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
||||
@@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
col_scale_f32 = _mm256_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
|
||||
}
|
||||
|
||||
// Process LHS in groups of four
|
||||
@@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
col_scale_f32 = _mm256_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
|
||||
}
|
||||
|
||||
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
||||
@@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
#if defined(__AVX2__)
|
||||
__m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
|
||||
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
||||
|
||||
gemv_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
|
||||
|
||||
return;
|
||||
#endif
|
||||
|
||||
ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
{
|
||||
__m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
|
||||
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
||||
|
||||
gemm_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
|
||||
|
||||
return;
|
||||
}
|
||||
#endif // defined(__AVX2__) || defined(__AVX512F__)
|
||||
|
||||
ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
|
||||
template <typename TA>
|
||||
class tinyBLAS_Q0_PPC {
|
||||
public:
|
||||
tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const block_q8_0 *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
int ith, int nth);
|
||||
|
||||
void matmul(int64_t m, int64_t n);
|
||||
void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
vec_t A_pack[mc*kc*2];
|
||||
vec_t B_pack[nc*kc*2];
|
||||
int comparray[mc*kc];
|
||||
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
|
||||
} else {
|
||||
packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
|
||||
}
|
||||
packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
|
||||
*c_ptr += *((float*)&fin_res[idx+I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType>
|
||||
inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
|
||||
vector signed int vec_C[4];
|
||||
vector float CA[4] = {0};
|
||||
vector float res[4] = {0};
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
c[0] = vec_and(c[1], lowMask);
|
||||
c[1] = vec_sr(c[1], v4);
|
||||
c[0] = vec_sub(c[0], v8);
|
||||
c[1] = vec_sub(c[1], v8);
|
||||
vsum = vec_sum4s(c[0], vsum);
|
||||
vsum2 = vec_sum4s(c[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template <typename V1, typename V2>
|
||||
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
vector unsigned char xor_vector;
|
||||
uint8_t flip_vec = 0x80;
|
||||
xor_vector = vec_splats(flip_vec);
|
||||
t1 = vec_perm(s1, s2, swiz1);
|
||||
t2 = vec_perm(s1, s2, swiz2);
|
||||
t3 = vec_perm(s3, s4, swiz1);
|
||||
t4 = vec_perm(s3, s4, swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
vec_xst(t7, 0, vecOffset+32);
|
||||
vec_xst(t8, 0, vecOffset+48);
|
||||
}
|
||||
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii,jj);
|
||||
} else {
|
||||
assert(false && "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
template<int size>
|
||||
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
|
||||
template<typename VA, typename VB>
|
||||
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj);
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj);
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj);
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
|
||||
template <int RM, int RN>
|
||||
void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
|
||||
|
||||
void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
|
||||
for (int I = 0; I<8; I++) {
|
||||
float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
|
||||
*((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q8_elements(const int8_t *qs, int *ca) {
|
||||
vector signed char c1 = vec_xl(0, qs);
|
||||
vector signed char c2 = vec_xl(16, qs);
|
||||
vector signed int vsum1 = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
vsum1 = vec_sum4s(c1, vsum1);
|
||||
vsum2 = vec_sum4s(c2, vsum2);
|
||||
vector signed int vsum = vec_add(vsum1, vsum2);
|
||||
*ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template<typename VA, typename VB>
|
||||
void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
|
||||
int64_t i, j;
|
||||
block_q8_0 *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
block_q8_0* aoffsets[8];
|
||||
__vector_pair arr[8];
|
||||
VB c[8][2] = {0};
|
||||
VB c1[8] = {0}; VB c2[8] = {0};
|
||||
aoffset = const_cast<block_q8_0*>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
int index = 0;
|
||||
if (j > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 8; it++)
|
||||
aoffsets[it] = aoffset + it*lda;
|
||||
aoffset += 8 * lda;
|
||||
for (int blk = 0; blk < kc; blk++) {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
if (comparray){
|
||||
process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
|
||||
}
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
||||
vecOffset += 256;
|
||||
}
|
||||
j--;
|
||||
index += 8*kc;
|
||||
} while(j > 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
int8_t *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
aoffset = const_cast<TA*>(a);
|
||||
vecOffset = vec;
|
||||
int index = 0;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
for (int blk = 0; blk < kc; blk++) {
|
||||
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
|
||||
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
|
||||
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
|
||||
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[index + 8*blk+0]);
|
||||
process_q4_elements(c2, &comparray[index + 8*blk+1]);
|
||||
process_q4_elements(c3, &comparray[index + 8*blk+2]);
|
||||
process_q4_elements(c4, &comparray[index + 8*blk+3]);
|
||||
process_q4_elements(c5, &comparray[index + 8*blk+4]);
|
||||
process_q4_elements(c6, &comparray[index + 8*blk+5]);
|
||||
process_q4_elements(c7, &comparray[index + 8*blk+6]);
|
||||
process_q4_elements(c8, &comparray[index + 8*blk+7]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
||||
vecOffset += 256;
|
||||
}
|
||||
j--;
|
||||
index += 8*kc;
|
||||
} while (j > 0);
|
||||
}
|
||||
}
|
||||
|
||||
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
|
||||
acc_t acc[8];
|
||||
for (int i = 0; i < mc ; i += 8) {
|
||||
for (int j = 0; j < nc; j += 8) {
|
||||
vector float fin_res[16] = {0};
|
||||
vector float vs[16] = {0};
|
||||
for (int64_t kk = 0; kk < kc; kk+=2) {
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
}
|
||||
int A_block_idx = (i/8)*(16*kc) + kk*16;
|
||||
int B_block_idx = (j/8)*(16*kc)+ kk*16;
|
||||
vec_t *A_block = &vec_A[A_block_idx];
|
||||
vec_t *B_block = &vec_B[B_block_idx];
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
|
||||
}
|
||||
compute_scale(ii+i, jj+j, l+kk, vs);
|
||||
int c_index = (i/8)*(8*kc)+ kk*8;
|
||||
int* c_block = &comparray[c_index];
|
||||
compute(&acc[0], 0, 0, c_block, vs, fin_res);
|
||||
compute(&acc[1], 4, 4, c_block, vs, fin_res);
|
||||
compute(&acc[2], 0, 8, c_block, vs, fin_res);
|
||||
compute(&acc[3], 4, 12, c_block, vs, fin_res);
|
||||
|
||||
A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
|
||||
B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
|
||||
A_block = &vec_A[A_block_idx];
|
||||
B_block = &vec_B[B_block_idx];
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
|
||||
}
|
||||
compute_scale(ii+i, jj+j, l+kk+1, vs);
|
||||
c_index = (i/8)*(8*kc)+ (kk+1)*8;
|
||||
c_block = &comparray[c_index];
|
||||
compute(&acc[4], 0, 0, c_block, vs, fin_res);
|
||||
compute(&acc[5], 4, 4, c_block, vs, fin_res);
|
||||
compute(&acc[6], 0, 8, c_block, vs, fin_res);
|
||||
compute(&acc[7], 4, 12, c_block, vs, fin_res);
|
||||
|
||||
}
|
||||
if (l == 0) {
|
||||
save_res(ii+i, jj+j, 0, fin_res);
|
||||
save_res(ii+i+4, jj+j, 4, fin_res);
|
||||
save_res(ii+i, jj+j+4, 8, fin_res);
|
||||
save_res(ii+i+4, jj+j+4, 12, fin_res);
|
||||
} else {
|
||||
add_save_res(ii+i, jj+j, 0, fin_res);
|
||||
add_save_res(ii+i+4, jj+j, 4, fin_res);
|
||||
add_save_res(ii+i, jj+j+4, 8, fin_res);
|
||||
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const TA *const A;
|
||||
const block_q8_0 *const B;
|
||||
float *C;
|
||||
const int64_t k;
|
||||
int64_t kc;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
|
||||
#endif
|
||||
|
||||
#if defined(__MMA__)
|
||||
#include "sgemm-ppc.h"
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// VECTORIZED FUSED MULTIPLY ADD
|
||||
@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
|
||||
const int nth;
|
||||
};
|
||||
|
||||
template <typename TA>
|
||||
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const block_q8_0 *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
template <typename TA>
|
||||
class tinyBLAS_Q0_PPC {
|
||||
public:
|
||||
tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA * A, int64_t lda,
|
||||
const block_q8_0 * B, int64_t ldb,
|
||||
float * C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
kc = 64;
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
|
||||
int mc = 64; int nc = 64;
|
||||
if (n % 8 == 0 && n < nc) {
|
||||
nc = n;
|
||||
mc = 32 ;
|
||||
kc = 32;
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
const int64_t mc = 64;
|
||||
const int64_t kc = 64;
|
||||
int64_t nc = 64;
|
||||
int64_t n_aligned = 0;
|
||||
if (n % 64 == 0) {
|
||||
n_aligned = n;
|
||||
} else if (n == 4) {
|
||||
n_aligned = 4;
|
||||
} else if (n < 64) {
|
||||
n_aligned = (n / 8) * 8;
|
||||
} else {
|
||||
n_aligned = (n / 64) * 64;
|
||||
}
|
||||
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
|
||||
if (is_aligned) {
|
||||
this->matmul_tiled_q0(m, n, mc, nc, kc);
|
||||
|
||||
if (n_aligned > 0) {
|
||||
if (n_aligned % 64 == 0) nc = 64;
|
||||
else if (n_aligned == n) nc = n;
|
||||
else if (n_aligned % 32 == 0) nc = 32;
|
||||
else if (n_aligned % 24 == 0) nc = 24;
|
||||
else if (n_aligned % 16 == 0) nc = 16;
|
||||
else nc = 8;
|
||||
}
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
||||
if (can_use_tiled) {
|
||||
matmul_tiled(m, n_aligned, mc, nc, kc);
|
||||
if (n > n_aligned) {
|
||||
mnpack(0, m, n_aligned, n);
|
||||
}
|
||||
} else {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<int size>
|
||||
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
|
||||
private:
|
||||
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
|
||||
*c_ptr += *((float *)&vec_C[I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType>
|
||||
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
|
||||
vector signed int vec_C[4];
|
||||
vector float CA[4] = {0};
|
||||
vector float res[4] = {0};
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
c[0] = vec_and(c[1], lowMask);
|
||||
c[1] = vec_sr(c[1], v4);
|
||||
c[0] = vec_sub(c[0], v8);
|
||||
c[1] = vec_sub(c[1], v8);
|
||||
vsum = vec_sum4s(c[0], vsum);
|
||||
vsum2 = vec_sum4s(c[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template <typename V1, typename V2>
|
||||
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
vector unsigned char xor_vector;
|
||||
uint8_t flip_vec = 0x80;
|
||||
xor_vector = vec_splats(flip_vec);
|
||||
t1 = vec_perm(s1, s2, swiz1);
|
||||
t2 = vec_perm(s1, s2, swiz2);
|
||||
t3 = vec_perm(s3, s4, swiz1);
|
||||
t4 = vec_perm(s3, s4, swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 16);
|
||||
vec_xst(t7, 0, vecOffset + 32);
|
||||
vec_xst(t8, 0, vecOffset + 48);
|
||||
}
|
||||
|
||||
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0x0F);
|
||||
const vector signed char v8 = vec_splats((signed char)0x08);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)4);
|
||||
lo = vec_and(packed, lowMask);
|
||||
hi = vec_sr(packed, v4);
|
||||
lo = vec_sub(lo, v8);
|
||||
hi = vec_sub(hi, v8);
|
||||
}
|
||||
|
||||
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
|
||||
vec_t t[8], s[8];
|
||||
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
for (int i = 0; i < 4; i += 2) {
|
||||
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
||||
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
||||
}
|
||||
for (int i = 4; i < 8; i += 2) {
|
||||
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
||||
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
||||
}
|
||||
s[0] = vec_perm(t[0], t[2], swiz3);
|
||||
s[1] = vec_perm(t[0], t[2], swiz4);
|
||||
s[2] = vec_perm(t[1], t[3], swiz3);
|
||||
s[3] = vec_perm(t[1], t[3], swiz4);
|
||||
s[4] = vec_perm(t[4], t[6], swiz3);
|
||||
s[5] = vec_perm(t[4], t[6], swiz4);
|
||||
s[6] = vec_perm(t[5], t[7], swiz3);
|
||||
s[7] = vec_perm(t[5], t[7], swiz4);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
|
||||
vector signed short i16_hi = vec_unpackh(raw);
|
||||
vector signed short i16_lo = vec_unpackl(raw);
|
||||
|
||||
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
|
||||
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
|
||||
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
|
||||
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
|
||||
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
|
||||
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
|
||||
}
|
||||
|
||||
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
unsigned char * vecOffset = vec;
|
||||
for (int i = 0; i < rows; i += 8) {
|
||||
const block_q4_0 * rows_base[8];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
rows_base[r] = a + (i + r) * lda;
|
||||
}
|
||||
for (int blk = 0; blk < blocks; blk++) {
|
||||
vector unsigned short hp_res[8][4];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
const block_q4_0 * current_blk = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
|
||||
vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
|
||||
vector signed char c1, c2;
|
||||
unpack_q4_to_q8(v_qs, c1, c2);
|
||||
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
|
||||
}
|
||||
for (int c = 0; c < 4; c++) {
|
||||
vector unsigned char c_arr[8];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
c_arr[r] = (vector unsigned char)hp_res[r][c];
|
||||
}
|
||||
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
|
||||
vecOffset += 128;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int chunk_size>
|
||||
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
unsigned char * vecOffset = vec;
|
||||
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
|
||||
for (int i = 0; i < rows; i += chunk_size) {
|
||||
const block_q8_0 * rows_base[chunk_size];
|
||||
for (int r = 0; r < chunk_size; r++) {
|
||||
rows_base[r] = a + (i + r) * lda;
|
||||
}
|
||||
for (int blk = 0; blk < blocks; blk++) {
|
||||
vector unsigned short hp_res[chunk_size][4];
|
||||
for (int r = 0; r < chunk_size; r++) {
|
||||
const block_q8_0 * b = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
|
||||
vector signed char c[2];
|
||||
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
|
||||
__builtin_vsx_disassemble_pair(c, & pair);
|
||||
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
|
||||
}
|
||||
for (int col = 0; col < 4; col++) {
|
||||
if constexpr (chunk_size == 8) {
|
||||
vec_t t[8];
|
||||
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
||||
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
||||
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
||||
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
||||
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
|
||||
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
|
||||
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
|
||||
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
|
||||
|
||||
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
|
||||
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
|
||||
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
|
||||
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
|
||||
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
|
||||
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
|
||||
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
|
||||
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
|
||||
vecOffset += 128;
|
||||
} else {
|
||||
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
||||
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
||||
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
||||
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
||||
|
||||
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
|
||||
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
|
||||
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
|
||||
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
|
||||
vecOffset += 64;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
if (rows == 4) {
|
||||
pack_q8_block<4>(a, lda, rows, blocks, vec);
|
||||
} else {
|
||||
pack_q8_block<8>(a, lda, rows, blocks, vec);
|
||||
}
|
||||
}
|
||||
|
||||
template<int size>
|
||||
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
int8_t *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
TA * aoffset = NULL;
|
||||
int8_t * vecOffset = NULL;
|
||||
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
|
||||
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
|
||||
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
aoffset = const_cast<TA*>(a);
|
||||
aoffset = const_cast<TA *>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c5, &comparray[4]);
|
||||
process_q4_elements(c6, &comparray[5]);
|
||||
process_q4_elements(c7, &comparray[6]);
|
||||
process_q4_elements(c8, &comparray[7]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
process_q4_elements(c5, & comparray[4]);
|
||||
process_q4_elements(c6, & comparray[5]);
|
||||
process_q4_elements(c7, & comparray[6]);
|
||||
process_q4_elements(c8, & comparray[7]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
|
||||
case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
|
||||
break;
|
||||
}
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<typename VA, typename VB>
|
||||
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
||||
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
|
||||
int64_t i, j;
|
||||
block_q8_0 *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
block_q8_0* aoffsets[8];
|
||||
block_q8_0 * aoffset = NULL;
|
||||
VA * vecOffset = NULL;
|
||||
block_q8_0 * aoffsets[8];
|
||||
__vector_pair arr[8];
|
||||
VB c[8][2] = {0};
|
||||
VB c1[8] = {0}; VB c2[8] = {0};
|
||||
aoffset = const_cast<block_q8_0*>(a);
|
||||
aoffset = const_cast<block_q8_0 *>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it < 8; it++)
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffsets[it] = aoffsets[it - 1] + lda;
|
||||
aoffset += 8 * lda;
|
||||
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
|
||||
for (int it = 0; it < 8; it++)
|
||||
aoffsets[it] += lda;
|
||||
vecOffset += 256;
|
||||
@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 4; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
for (int it = 0; it < 4; it++) {
|
||||
aoffsets[it] += lda;
|
||||
}
|
||||
@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
|
||||
if (rows & 3) {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it < 3; it++ )
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffsets[it] = aoffsets[it - 1] + lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
switch(rows) {
|
||||
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
|
||||
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
|
||||
c1[2] = c[2][0]; c2[2] = c[2][1];
|
||||
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
|
||||
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
|
||||
c1[1] = c[1][0]; c2[1] = c[1][1];
|
||||
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
|
||||
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
|
||||
c1[0] = c[0][0]; c2[0] = c[0][1];
|
||||
break;
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
for (int it = 0; it < 3; it++)
|
||||
aoffsets[it] += lda;
|
||||
vecOffset += 128;
|
||||
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int m_rem = MIN(m - m0, 16);
|
||||
int n_rem = MIN(n - n0, 16);
|
||||
|
||||
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[8], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 4> comparray {};
|
||||
@@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
|
||||
}
|
||||
for (int I = 0; I<4; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 0, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 0, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii, jj+4, 4, fin_res);
|
||||
save_res(ii, jj + 4, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[8] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 8> comparray {};
|
||||
@@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||
acc_t acc_4, acc_5, acc_6, acc_7;
|
||||
@@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[16] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_2);
|
||||
__builtin_mma_xxsetaccz(& acc_3);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8 ; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(&acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(&acc_3, 4, 12, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(& acc_3, 4, 12, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii, jj+4, 8, fin_res);
|
||||
save_res(ii+4, jj+4, 12, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
save_res(ii, jj + 4, 8, fin_res);
|
||||
save_res(ii + 4, jj + 4, 12, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
|
||||
acc_t acc[8];
|
||||
for (int i = 0; i < mc ; i += 16) {
|
||||
for (int j = 0; j < nc; j += 8) {
|
||||
int A0_base = (i / 16) * (2 * 32 * kc);
|
||||
int B0_base = (j / 8) * (32 * kc);
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
}
|
||||
for (int64_t kk = 0; kk < kc; kk++) {
|
||||
int A0_block_idx = A0_base + kk * 32;
|
||||
int B0_block_idx = B0_base + kk * 32;
|
||||
int A1_block_idx = A0_block_idx + 32 * kc;
|
||||
int B1_block_idx = B0_block_idx + 32 * kc;
|
||||
vec_t * A0_block = & vec_A[A0_block_idx];
|
||||
vec_t * B0_block = & vec_B[B0_block_idx];
|
||||
vec_t * A1_block = & vec_A[A1_block_idx];
|
||||
for (int it = 0; it < 4; it++) {
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (l == 0) {
|
||||
save_acc(& acc[0], ii + i, jj + j);
|
||||
save_acc(& acc[1], ii + i, jj + j + 4);
|
||||
save_acc(& acc[2], ii + i + 4, jj + j);
|
||||
save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
||||
save_acc(& acc[4], ii + i + 8, jj + j);
|
||||
save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
||||
save_acc(& acc[6], ii + i + 12, jj + j);
|
||||
save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
||||
} else {
|
||||
add_save_acc(& acc[0], ii + i, jj + j);
|
||||
add_save_acc(& acc[1], ii + i, jj + j + 4);
|
||||
add_save_acc(& acc[2], ii + i + 4, jj + j);
|
||||
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
||||
add_save_acc(& acc[4], ii + i + 8, jj + j);
|
||||
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
||||
add_save_acc(& acc[6], ii + i + 12, jj + j);
|
||||
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
vec_t A_pack[mc * kc * 4];
|
||||
vec_t B_pack[nc * kc * 4];
|
||||
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
} else {
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
}
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float fin_res[4] = {0};
|
||||
vector float vs[4] = {0};
|
||||
vector float CA[4] = {0};
|
||||
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
if (isAblock_q4) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x+=4) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
|
||||
for (int x = 0; x < 8; x += 4) {
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
|
||||
}
|
||||
for (int I = 0; I<RM; I++) {
|
||||
for (int J = 0; J<RN; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
__builtin_mma_disassemble_acc(vec_C, & acc_0);
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < RM; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii,jj);
|
||||
} else {
|
||||
assert(false && "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <int RM, int RN>
|
||||
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
this->kernel<RM, RN>(ii, jj);
|
||||
kernel<RM, RN>(ii, jj);
|
||||
}
|
||||
}
|
||||
|
||||
template class tinyBLAS_Q0_PPC<block_q4_0>;
|
||||
template class tinyBLAS_Q0_PPC<block_q8_0>;
|
||||
const TA * const A;
|
||||
const block_q8_0 * const B;
|
||||
float * C;
|
||||
const int64_t k;
|
||||
int64_t kc;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
|
||||
class tinyBLAS_PPC {
|
||||
public:
|
||||
|
||||
@@ -450,6 +450,208 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[ncols_interleaved];
|
||||
float sum_minf[ncols_interleaved];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
constexpr int scale_stride = 32;
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
|
||||
|
||||
const int qh_shift = (k / (32 / blocklen)) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * blocklen + i) % 32;
|
||||
const int qh_chunk = qh_idx / blocklen;
|
||||
const int qh_pos = qh_idx % blocklen;
|
||||
const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][ncols_interleaved];
|
||||
float sum_minf[4][ncols_interleaved];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
constexpr int scale_stride = 32;
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
|
||||
|
||||
const int qh_shift = (k / (32 / blocklen)) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * blocklen + i) % 32;
|
||||
const int qh_chunk = qh_idx / blocklen;
|
||||
const int qh_pos = qh_idx % blocklen;
|
||||
const int b_qh_offset =
|
||||
qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k / (32 / blocklen)) * 256 +
|
||||
(k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -803,98 +1005,12 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
|
||||
@@ -982,6 +1098,82 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert(nr == 1);
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[4];
|
||||
int sumi;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(nr == 1);
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
int sumi;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -1494,107 +1686,12 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -1705,6 +1802,94 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][4];
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
||||
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++)
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
||||
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++)
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -2029,18 +2214,16 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
|
||||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
// Interleave Q5_K quants by taking blck_size_interleave bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// Repeat for high bits with the same chunk size, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
@@ -2049,9 +2232,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
@@ -2249,7 +2430,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
@@ -2493,6 +2674,121 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
|
||||
static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {
|
||||
block_mxfp4x4 out;
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
out.e[i] = in[i].e;
|
||||
}
|
||||
|
||||
const int end = QK_MXFP4 * 2 / blck_size_interleave;
|
||||
|
||||
if (blck_size_interleave == 4) {
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 4;
|
||||
int src_offset = (i / 4) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
|
||||
GGML_ASSERT(interleave_block == 4);
|
||||
|
||||
const block_mxfp4 * src = (const block_mxfp4 *)data;
|
||||
block_mxfp4x4 * dst = ( block_mxfp4x4 *)t->data;
|
||||
|
||||
block_mxfp4 dst_tmp[4];
|
||||
|
||||
int nrow = ggml_nrows(t);
|
||||
int nrows_interleaved = 4;
|
||||
int nblocks = t->ne[0] / QK_MXFP4;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_mxfp4x4(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) {
|
||||
block_mxfp4x8 out;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.e[i] = in[i].e;
|
||||
}
|
||||
|
||||
const int end = QK_MXFP4 * 4 / blck_size_interleave;
|
||||
|
||||
if (blck_size_interleave == 8) {
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
|
||||
const block_mxfp4 * src = (const block_mxfp4 *)data;
|
||||
block_mxfp4x8 * dst = ( block_mxfp4x8 *)t->data;
|
||||
|
||||
block_mxfp4 dst_tmp[8];
|
||||
|
||||
int nrow = ggml_nrows(t);
|
||||
int nrows_interleaved = 8;
|
||||
int nblocks = t->ne[0] / QK_MXFP4;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_mxfp4x8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::repack {
|
||||
// repack
|
||||
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
||||
@@ -2523,6 +2819,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
@@ -2548,6 +2848,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
|
||||
return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_mxfp4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_mxfp4, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
@@ -2591,6 +2899,10 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2611,6 +2923,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
|
||||
ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2654,6 +2974,10 @@ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2674,6 +2998,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
|
||||
ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -3068,6 +3400,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q6_K
|
||||
@@ -3081,6 +3414,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
|
||||
|
||||
// instance for MXFP4
|
||||
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 4, 4, GGML_TYPE_Q8_0> mxfp4_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 8, 8, GGML_TYPE_Q8_0> mxfp4_8x8_q8_0;
|
||||
|
||||
// instance for Q8_0
|
||||
static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
|
||||
@@ -3130,6 +3467,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x4_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q6_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
@@ -3152,6 +3494,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &iq4_nl_4x4_q8_0;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_MXFP4) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &mxfp4_8x8_q8_0;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &mxfp4_4x4_q8_0;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q8_0) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
|
||||
@@ -97,6 +97,19 @@ struct block_iq4_nlx8 {
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
|
||||
|
||||
struct block_mxfp4x4 {
|
||||
uint8_t e[4];
|
||||
uint8_t qs[QK_MXFP4 * 2];
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding");
|
||||
|
||||
struct block_mxfp4x8 {
|
||||
uint8_t e[8];
|
||||
uint8_t qs[QK_MXFP4 * 4];
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
|
||||
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -111,22 +124,28 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -143,22 +162,28 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -1149,8 +1149,7 @@ struct ggml_cuda_graph {
|
||||
size_t num_nodes = 0;
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool disable_due_to_too_many_updates = false;
|
||||
int number_consecutive_updates = 0;
|
||||
bool warmup_complete = false;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
@@ -1159,21 +1158,9 @@ struct ggml_cuda_graph {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
|
||||
void record_update(bool use_graph, bool update_required) {
|
||||
if (use_graph && update_required) {
|
||||
number_consecutive_updates++;
|
||||
} else {
|
||||
number_consecutive_updates = 0;
|
||||
}
|
||||
if (number_consecutive_updates >= 4) {
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
||||
disable_due_to_too_many_updates = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_enabled() const {
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
|
||||
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
@@ -628,18 +628,18 @@ static __global__ void convert_unary(
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
|
||||
const src_t * x = (const src_t *) vx;
|
||||
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
|
||||
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
|
||||
@@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
|
||||
// Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
|
||||
|
||||
// Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
|
||||
// compile-time static_asserts even though the kernel guard prevents runtime execution.
|
||||
// nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
|
||||
return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
|
||||
}
|
||||
|
||||
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
if (ampere_mma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
@@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
|
||||
if (turing_mma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
||||
}
|
||||
if (amd_mfma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
||||
}
|
||||
if (amd_wmma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
||||
}
|
||||
@@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
#elif defined(TURING_MMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
||||
#elif defined(VOLTA_MMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
@@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_cols_per_thread() {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
return 1; // RDNA has a single column.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
return 1; // AMD has a single column per thread.
|
||||
#else
|
||||
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
static __host__ int get_cols_per_warp(const int cc) {
|
||||
if (turing_mma_available(cc) || amd_wmma_available(cc)) {
|
||||
if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
|
||||
return 16;
|
||||
} else {
|
||||
// Volta
|
||||
@@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
|
||||
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
||||
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
||||
if constexpr (use_cp_async) {
|
||||
@@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||
|
||||
auto load = [&] __device__ (auto n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int stride_k = warp_size >> n;
|
||||
const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
const int stride_i = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
return;
|
||||
@@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
||||
}
|
||||
@@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
} else {
|
||||
// TODO use ggml_cuda_memcpy_1
|
||||
auto load = [&] __device__ (const int n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
|
||||
const int stride_k = warp_size >> n;
|
||||
const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
|
||||
const int k0_stop = D2 - D2 % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
const int stride_i = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
return;
|
||||
@@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
|
||||
}
|
||||
@@ -324,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
|
||||
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
if constexpr (use_cp_async) {
|
||||
static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
||||
static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
||||
static_assert(!oob_check, "OOB check incompatible with cp_async");
|
||||
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
||||
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
||||
constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
|
||||
constexpr int stride_j = nwarps * cols_per_warp;
|
||||
|
||||
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
||||
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
||||
|
||||
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
||||
@@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
||||
}
|
||||
}
|
||||
} else if constexpr (nbatch_fa < 2*WARP_SIZE) {
|
||||
constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
|
||||
} else if constexpr (nbatch_fa < 2*warp_size) {
|
||||
constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
|
||||
constexpr int stride_j = nwarps * cols_per_warp;
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
||||
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
||||
|
||||
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
|
||||
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
||||
}
|
||||
@@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
||||
@@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int jt,
|
||||
const int kb0,
|
||||
const int k_VKQ_sup) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
@@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int k_VKQ_0 = kb0 * nbatch_fa;
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
||||
#else // Volta
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
||||
@@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
@@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset >= 4; offset >>= 1) {
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
|
||||
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
|
||||
} else {
|
||||
@@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
@@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
// Values per KQ column are spread across 4 threads:
|
||||
constexpr int offset_first = 2;
|
||||
constexpr int offset_last = 1;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
|
||||
constexpr int offset_first = 32;
|
||||
constexpr int offset_last = 16;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
// Values per KQ column are spread across 2 threads:
|
||||
constexpr int offset_first = 16;
|
||||
@@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
|
||||
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
|
||||
} else {
|
||||
@@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(
|
||||
KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
@@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
||||
@@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
||||
#if defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
|
||||
// Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
|
||||
// Load with transposed addressing: 4 strided half loads.
|
||||
{
|
||||
const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
|
||||
const half * xs0_h = (const half *) xs0;
|
||||
const int stride_h = stride_tile_V * 2; // stride in half units
|
||||
half * A_h = (half *) A.x;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO: Try to transpose tile_V when loading gmem to smem.
|
||||
// Use mma to transpose T_A_VKQ for RDNA.
|
||||
T_A_VKQ A_trans;
|
||||
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
mma(A, A_trans, A_identity);
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
if constexpr (T_B_KQ::I == 8) {
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
} else {
|
||||
// Wide version of VKQ_C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
|
||||
}
|
||||
}
|
||||
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
if constexpr (nstages <= 1) {
|
||||
__syncthreads(); // Only needed if tile_K == tile_V.
|
||||
@@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
tile_Q, tile_K, tile_V, tile_mask,
|
||||
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
@@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> {
|
||||
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
||||
};
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template<int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
@@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int zt_gqa,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
||||
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
||||
@@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
#else // Volta
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
@@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
||||
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
||||
const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
||||
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
|
||||
const int stride_jc = WARP_SIZE / stride_k;
|
||||
const int stride_jc = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
@@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
|
||||
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
|
||||
break;
|
||||
@@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
||||
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
||||
@@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
@@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// The partial sums are spread across 8/4 threads.
|
||||
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
|
||||
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// The partial sums are spread across 4 threads (wavefront64, 16 cols).
|
||||
constexpr int offset_first = 32;
|
||||
constexpr int offset_last = 16;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
// The partial sums are spread across 2 threads.
|
||||
constexpr int offset_first = 16;
|
||||
@@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
@@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
|
||||
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
|
||||
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
|
||||
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
|
||||
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
|
||||
@@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// Warps with threadIdx.y % np != 0 must NOT return early.
|
||||
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
||||
|
||||
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
|
||||
constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
|
||||
|
||||
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
||||
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
||||
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
|
||||
float2 meta[nmeta];
|
||||
#pragma unroll
|
||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||
meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
|
||||
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
|
||||
}
|
||||
|
||||
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
||||
@@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
||||
if (offset < WARP_SIZE) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
||||
if (offset < warp_size) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
||||
if (offset < WARP_SIZE) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
||||
if (offset < warp_size) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// Write back combined meta data:
|
||||
#pragma unroll
|
||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
|
||||
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
|
||||
// Combined KQ max scale + rowsum.
|
||||
meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
||||
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
||||
}
|
||||
}
|
||||
|
||||
// Combined KQ max + rowsum.
|
||||
static_assert(cols_per_warp <= WARP_SIZE);
|
||||
if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
||||
static_assert(cols_per_warp <= warp_size);
|
||||
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
||||
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
@@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
|
||||
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
||||
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
||||
const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
||||
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
|
||||
const int stride_jc = WARP_SIZE / stride_k;
|
||||
const int stride_jc = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
@@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
||||
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
||||
break;
|
||||
@@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
float2 dstk_val = make_float2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
@@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
||||
jt, kb0_start, kb0_stop);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
@@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
||||
@@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
||||
constexpr int nwarps = nthreads / WARP_SIZE;
|
||||
constexpr int nwarps = nthreads / warp_size;
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
@@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
}
|
||||
|
||||
template <int DKQ, int DV, int ncols1, int ncols2>
|
||||
@@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
|
||||
|
||||
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||
const int nwarps = nthreads / warp_size_host;
|
||||
|
||||
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
||||
|
||||
@@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
}
|
||||
|
||||
launch_fattn<DV, ncols1, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
// On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
|
||||
// However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
|
||||
const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DV == 512) {
|
||||
|
||||
@@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use MFMA flash attention for CDNA (MI100+):
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
|
||||
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
|
||||
// MMA vs tile crossover benchmarked on MI300X @ d32768:
|
||||
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
|
||||
// hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
|
||||
if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
// Fall through to tile kernel for small effective batch sizes.
|
||||
}
|
||||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
|
||||
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
|
||||
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
|
||||
if (ggml_is_quantized(src0->type)) {
|
||||
if (ne2 <= 4) {
|
||||
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
|
||||
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(nb12 % nb11 == 0);
|
||||
@@ -2865,15 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
bool use_cuda_graph = true;
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
||||
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
||||
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
|
||||
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
|
||||
const std::string delta_net_prefix = "dnet_add";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
@@ -2888,31 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
|
||||
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD &&
|
||||
node->src[1] && node->src[1]->ne[1] > 1 &&
|
||||
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
|
||||
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
||||
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
|
||||
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
|
||||
strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) {
|
||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||
// by means of matching node names. See
|
||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
|
||||
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
|
||||
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
|
||||
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -3002,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->instance == nullptr) {
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes) {
|
||||
res = true;
|
||||
@@ -3954,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
|
||||
if (graph_compatible) {
|
||||
const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
|
||||
graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
if (!graph->warmup_complete) {
|
||||
// Warmup: need at least 2 calls with no property change on the 2nd call
|
||||
if (!properties_changed) {
|
||||
graph->warmup_complete = true;
|
||||
GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
|
||||
use_cuda_graph = true;
|
||||
cuda_graph_update_required = true;
|
||||
}
|
||||
// else: properties changed or first call - execute directly (use_cuda_graph stays false)
|
||||
} else {
|
||||
// Post-warmup: normal CUDA graph operation
|
||||
if (properties_changed) {
|
||||
// Properties changed - reset warmup, execute directly until stable again
|
||||
graph->warmup_complete = false;
|
||||
GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
|
||||
} else {
|
||||
use_cuda_graph = true;
|
||||
cuda_graph_update_required = graph->instance == nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
|
||||
@@ -668,7 +668,7 @@ namespace ggml_cuda_mma {
|
||||
|
||||
return ret;
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
@@ -964,6 +964,34 @@ namespace ggml_cuda_mma {
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(RDNA4)
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA: FP16 input, FP32 accumulate, convert back to half2.
|
||||
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
|
||||
// Convert existing half2 accumulator to float for MFMA:
|
||||
floatx4_t acc_f32;
|
||||
{
|
||||
const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
acc_f32[i] = (float)acc_h[i];
|
||||
}
|
||||
}
|
||||
|
||||
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
||||
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
||||
acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
|
||||
|
||||
// Convert back to half2:
|
||||
{
|
||||
halfx4_t result_h;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result_h[i] = (_Float16)acc_f32[i];
|
||||
}
|
||||
reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
@@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe
|
||||
return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
|
||||
}
|
||||
|
||||
static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
|
||||
if (x->ne[0] != y->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[1] != y->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[2] != y->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[3] != y->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
@@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return opt_experimental;
|
||||
}
|
||||
|
||||
static bool hex_supported_src0_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src2_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type2(ggml_type t) {
|
||||
return t == GGML_TYPE_F16;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type3(ggml_type t) {
|
||||
return t == GGML_TYPE_I32;
|
||||
}
|
||||
|
||||
static bool hex_supported_dst_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
|
||||
// TODO: support broadcast for ne[2 and 3]
|
||||
if (x->ne[0] != y->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[2] != y->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[3] != y->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
@@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_can_repeat(src1, src0)) {
|
||||
if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session *
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
||||
}
|
||||
|
||||
if (src1) {
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, src1)) {
|
||||
if (!ggml_are_same_shape(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
@@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
|
||||
return false; // FIXME: add support for sinks
|
||||
}
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1) {
|
||||
if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (src0->ne[0] != src1->ne[0]) {
|
||||
@@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
||||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false; // FIXME: add support for GGML_TYPE_F16 for src0
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type3(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
if (src2) {
|
||||
if (!hex_supported_src2_type(src2->type)) {
|
||||
if (src2->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
int n_dims = op_params[1];
|
||||
|
||||
@@ -69,27 +69,45 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
struct htp_act_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
// Precomputed values
|
||||
const uint8_t * data_src0;
|
||||
const uint8_t * data_src1;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t src1_row_size;
|
||||
size_t dst_row_size;
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t src1_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
|
||||
size_t src0_spad_half_size;
|
||||
size_t src1_spad_half_size;
|
||||
size_t dst_spad_half_size;
|
||||
|
||||
uint32_t block;
|
||||
uint32_t src0_nrows;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
int nc;
|
||||
};
|
||||
|
||||
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
@@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
|
||||
"%zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
const float alpha = ((const float *) (op_params))[2];
|
||||
const float limit = ((const float *) (op_params))[3];
|
||||
const float alpha = ((const float *) (actx->octx->op_params))[2];
|
||||
const float limit = ((const float *) (actx->octx->op_params))[3];
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
@@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
}
|
||||
|
||||
|
||||
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size = actx->src0_row_size;
|
||||
const size_t dst_row_size = actx->dst_row_size;
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * data_src0 = actx->data_src0;
|
||||
uint8_t * data_dst = actx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
// nc/ne0 matches.
|
||||
const int ne0_val = actx->nc; // == dst->ne[0]
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// In gelu = x*sigmoid(x*1.702)
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size = actx->src0_row_size;
|
||||
const size_t dst_row_size = actx->dst_row_size;
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * data_src0 = actx->data_src0;
|
||||
uint8_t * data_dst = actx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
const int ne0_val = actx->nc; // == dst->ne[0]
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = actx->block;
|
||||
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
@@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_UNARY_SILU:
|
||||
act_op_func = unary_silu_f32;
|
||||
act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
|
||||
op_type = "silu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
act_op_func = glu_swiglu_f32;
|
||||
act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
|
||||
op_type = "swiglu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
act_op_func = glu_swiglu_oai_f32;
|
||||
act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
|
||||
op_type = "swiglu-oai-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_GELU:
|
||||
act_op_func = unary_gelu_f32;
|
||||
act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
act_op_func = glu_geglu_f32;
|
||||
act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
|
||||
op_type = "geglu-f32";
|
||||
break;
|
||||
default:
|
||||
@@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
|
||||
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
return err;
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
// Prepare context
|
||||
struct htp_act_context actx;
|
||||
actx.octx = octx;
|
||||
|
||||
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
actx.src0_row_size = src0_row_size;
|
||||
actx.src1_row_size = src1_row_size;
|
||||
actx.dst_row_size = dst_row_size;
|
||||
|
||||
actx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
actx.src1_row_size_aligned = src1_row_size_aligned;
|
||||
actx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
|
||||
actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
|
||||
actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
|
||||
actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
|
||||
|
||||
actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
|
||||
actx.src0_nrows = src0_nrows;
|
||||
|
||||
actx.nc = dst->ne[0];
|
||||
|
||||
// Pointers and GLU logic
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * data_src1 = (const uint8_t *) src1->data;
|
||||
|
||||
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
const int32_t swapped = octx->op_params[1];
|
||||
data_src1 = data_src0;
|
||||
actx.src1_row_size = actx.src0_row_size;
|
||||
|
||||
size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
|
||||
if (swapped) {
|
||||
data_src0 += nc_in_bytes;
|
||||
} else {
|
||||
data_src1 += nc_in_bytes;
|
||||
}
|
||||
}
|
||||
|
||||
actx.data_src0 = data_src0;
|
||||
actx.data_src1 = data_src1;
|
||||
actx.data_dst = (uint8_t *) dst->data;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
|
||||
@@ -15,6 +15,13 @@
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
struct get_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
struct fastdiv_values get_rows_div_ne10;
|
||||
struct fastdiv_values get_rows_div_ne10_ne11;
|
||||
};
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
@@ -39,20 +46,22 @@
|
||||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct get_rows_context * grctx = (struct get_rows_context *)data;
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = octx->src1_nrows_per_thread;
|
||||
const uint32_t dr = grctx->src1_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
|
||||
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
|
||||
const uint32_t rem = i - i12 * ne11 * ne10;
|
||||
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
|
||||
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
@@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
||||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
@@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
struct get_rows_context grctx;
|
||||
grctx.octx = octx;
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
|
||||
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
return true;
|
||||
}
|
||||
@@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dptr;
|
||||
}
|
||||
|
||||
static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
|
||||
dma_ptr dptr = { NULL };
|
||||
|
||||
if (q->push_idx == q->pop_idx) {
|
||||
return dptr;
|
||||
}
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dptr;
|
||||
}
|
||||
|
||||
static inline bool dma_queue_empty(dma_queue * q) {
|
||||
return q->push_idx == q->pop_idx;
|
||||
}
|
||||
|
||||
static inline uint32_t dma_queue_depth(dma_queue * q) {
|
||||
return (q->push_idx - q->pop_idx) & q->idx_mask;
|
||||
}
|
||||
|
||||
static inline uint32_t dma_queue_capacity(dma_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
@@ -44,32 +44,6 @@ struct htp_ops_context {
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
|
||||
struct fastdiv_values src0_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src0_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src0_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src1_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src1_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src1_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src3_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src3_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src3_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
|
||||
@@ -49,62 +49,6 @@ struct htp_matmul_context {
|
||||
struct fastdiv_values mm_div_r3;
|
||||
};
|
||||
|
||||
// vdelta control to replicate first 4x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
|
||||
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
||||
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
|
||||
0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
|
||||
};
|
||||
|
||||
// vdelta control to replicate and interleave first 8x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
|
||||
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
|
||||
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
||||
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
|
||||
0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp32 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
|
||||
0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
|
||||
0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
|
||||
0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
|
||||
0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
|
||||
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
|
||||
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
|
||||
0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
|
||||
0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
|
||||
0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
|
||||
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
};
|
||||
|
||||
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
|
||||
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
|
||||
@@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
|
||||
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
|
||||
|
||||
// Convert to QF32
|
||||
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
|
||||
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
|
||||
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
|
||||
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
|
||||
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
|
||||
|
||||
// Combine and convert to fp16
|
||||
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
|
||||
@@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
|
||||
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
|
||||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
|
||||
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
|
||||
@@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
|
||||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
|
||||
|
||||
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
@@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
|
||||
|
||||
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
@@ -21,6 +22,9 @@
|
||||
#define HTP_ROPE_TYPE_NORMAL 0
|
||||
#define HTP_ROPE_TYPE_NEOX 2
|
||||
|
||||
#define HTP_ROPE_SPAD_NROWS 16
|
||||
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
|
||||
|
||||
#define htp_rope_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
@@ -42,7 +46,7 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct rope_th_ctx {
|
||||
struct htp_rope_context {
|
||||
int32_t n_dims;
|
||||
int32_t mode;
|
||||
int32_t n_ctx_orig;
|
||||
@@ -57,7 +61,19 @@ struct rope_th_ctx {
|
||||
float theta_scale;
|
||||
float corr_dims[2];
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
size_t theta_cache_offset;
|
||||
uint32_t src0_nrows;
|
||||
|
||||
uint64_t t_start;
|
||||
};
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
@@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims,
|
||||
dims[1] = MIN(n_dims - 1, end);
|
||||
}
|
||||
|
||||
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
|
||||
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
|
||||
static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
|
||||
|
||||
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
|
||||
rope_ctx->mode = ((const int32_t *) op_params)[2];
|
||||
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
uint32_t he = ne / 2; // half_dims offset in elements
|
||||
uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
|
||||
|
||||
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i += 2) {
|
||||
HVX_Vector v0 = vsrc[i/2+0];
|
||||
HVX_Vector v1 = vsrc[i/2+hv];
|
||||
|
||||
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
|
||||
|
||||
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
|
||||
rope_ctx->beta_slow, rope_ctx->corr_dims);
|
||||
|
||||
rope_ctx->octx = octx;
|
||||
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
|
||||
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[num_elems/2];
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
//src += 1;
|
||||
//dst += 1;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
int half_size = (sizeof(float) * (num_elems / 2));
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
|
||||
@@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
|
||||
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
|
||||
vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
|
||||
vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
|
||||
}
|
||||
|
||||
src0_curr += VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += VLEN;
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i/2];
|
||||
float x1 = src0[i/2 + he];
|
||||
dst[i/2] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[1];
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[1] = x0*sin_theta + x1*cos_theta;
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i+=2) {
|
||||
HVX_Vector v0 = vsrc[i+0];
|
||||
HVX_Vector v1 = vsrc[i+1];
|
||||
|
||||
//src += 2;
|
||||
//dst += 2;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
@@ -239,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
|
||||
vdst[i+0] = Q6_V_lo_W(vstore);
|
||||
vdst[i+1] = Q6_V_hi_W(vstore);
|
||||
}
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += 2 * VLEN;
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i+0];
|
||||
float x1 = src0[i+1];
|
||||
dst[i+0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i+1] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||
const uint32_t ir0,
|
||||
const uint32_t ir1,
|
||||
int nth,
|
||||
int ith,
|
||||
const int opt_path) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
|
||||
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t i = 0; i < nr; i++) {
|
||||
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
|
||||
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
|
||||
|
||||
hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
|
||||
|
||||
// fill the remain channels with data from src tensor
|
||||
if (rctx->n_dims < ne0) {
|
||||
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
|
||||
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t i = 0; i < nr; i++) {
|
||||
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
|
||||
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
|
||||
|
||||
hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
|
||||
|
||||
// fill the remain channels with data from src tensor
|
||||
if (rctx->n_dims < ne0) {
|
||||
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const int32_t mode = rope_ctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
const uint32_t i1_end = MIN(ir1, ne1);
|
||||
const int32_t half_dims = rope_ctx->n_dims / 2;
|
||||
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
const int32_t p = pos[i2];
|
||||
|
||||
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
||||
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
||||
|
||||
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
|
||||
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
||||
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
||||
|
||||
const float * src_loc = src;
|
||||
float * dst_data_loc = dst_data;
|
||||
|
||||
if (1 == opt_path) {
|
||||
if (is_neox) {
|
||||
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
} else {
|
||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
}
|
||||
|
||||
src_loc += rope_ctx->n_dims;
|
||||
dst_data_loc += rope_ctx->n_dims;
|
||||
} else {
|
||||
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||
const float cos_theta = wp0[i0 + 0];
|
||||
const float sin_theta = wp0[i0 + 1];
|
||||
|
||||
if (is_neox) {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[half_dims];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 1;
|
||||
dst_data_loc += 1;
|
||||
} else {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[1];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
}
|
||||
}
|
||||
|
||||
src_loc += (is_neox ? half_dims : 0);
|
||||
dst_data_loc += (is_neox ? half_dims : 0);
|
||||
}
|
||||
|
||||
// TODO: use simd to speed up the remaining elements copy
|
||||
memcpy(dst_data_loc, src_loc, remain_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const uint32_t src0_nrows = rctx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -358,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
uint64_t tt = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
const int32_t mode = rctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
// VTCM setup
|
||||
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
float * theta_cache = (float *) (src0_spad_base);
|
||||
src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
|
||||
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
|
||||
|
||||
uint32_t ir = 0;
|
||||
uint32_t prev_i2 = (uint32_t) -1;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
|
||||
if (ir < src0_start_row) { ir++; i1++; continue; }
|
||||
if (ir >= src0_end_row) goto done;
|
||||
|
||||
// Rows in this block
|
||||
const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
|
||||
|
||||
// Depth before prefetch
|
||||
uint32_t dma_depth = dma_queue_depth(dma_queue);
|
||||
|
||||
// FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
// Prefetch loop
|
||||
for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
|
||||
pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
|
||||
|
||||
uint32_t pi1 = i1 + pr;
|
||||
uint32_t pir = ir + pr;
|
||||
|
||||
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
|
||||
|
||||
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
|
||||
uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
|
||||
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
|
||||
|
||||
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
|
||||
}
|
||||
|
||||
// Update theta cache
|
||||
if (i2 != prev_i2) {
|
||||
prev_i2 = i2;
|
||||
|
||||
const int32_t p = pos[i2];
|
||||
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
|
||||
|
||||
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
}
|
||||
|
||||
// Skip DMA transactions from prev block (if any)
|
||||
// No need to wait for these since the DMA is setup for in-order processing
|
||||
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
|
||||
|
||||
// Compute loop
|
||||
for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
|
||||
// Number of rows to compute
|
||||
cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
|
||||
|
||||
uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
|
||||
uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
// FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
if (is_neox) {
|
||||
rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
|
||||
} else {
|
||||
rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
|
||||
}
|
||||
|
||||
uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
|
||||
|
||||
// Prefetch more rows (if any)
|
||||
if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
|
||||
uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
|
||||
uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
|
||||
uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
|
||||
|
||||
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
|
||||
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
|
||||
|
||||
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
|
||||
done:
|
||||
dma_queue_flush(dma_queue);
|
||||
tt = HAP_perf_get_qtimer_count() - tt;
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
|
||||
|
||||
rope_job_f32_per_thread(rope_ctx, n, i);
|
||||
FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
|
||||
}
|
||||
|
||||
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
@@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct rope_th_ctx rope_ctx;
|
||||
const char * op_type = "rope-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_ROPE:
|
||||
op_func = rope_job_dispatcher_f32;
|
||||
op_type = "rope-f32";
|
||||
|
||||
init_rope_ctx(&rope_ctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
// Aligned row sizes for VTCM
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
// Calculate spad sizes per thread
|
||||
size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
|
||||
size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
|
||||
size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
|
||||
|
||||
if (src2->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
|
||||
"dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
// Check if we fit in VTCM
|
||||
size_t total_vtcm_needed = spad_per_thread * n_threads;
|
||||
if (octx->ctx->vtcm_size < total_vtcm_needed) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Assign sizes
|
||||
octx->src0_spad.size_per_thread = src0_spad_per_thread;
|
||||
octx->dst_spad.size_per_thread = dst_spad_per_thread;
|
||||
octx->src0_spad.size = n_threads * src0_spad_per_thread;
|
||||
octx->dst_spad.size = n_threads * dst_spad_per_thread;
|
||||
octx->src1_spad.size = 0;
|
||||
|
||||
// Assign pointers
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
// Fill context
|
||||
struct htp_rope_context rctx;
|
||||
memset(&rctx, 0, sizeof(struct htp_rope_context));
|
||||
|
||||
rctx.t_start = HAP_perf_get_qtimer_count();
|
||||
|
||||
rctx.octx = octx;
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
rctx.n_dims = ((const int32_t *) op_params)[1];
|
||||
rctx.mode = ((const int32_t *) op_params)[2];
|
||||
rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
|
||||
memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
|
||||
rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
|
||||
|
||||
rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
|
||||
|
||||
rctx.src0_row_size = src0_row_size;
|
||||
rctx.dst_row_size = dst_row_size;
|
||||
rctx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
rctx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
rctx.theta_cache_offset = theta_cache_size_aligned;
|
||||
|
||||
uint32_t ne0 = dst->ne[0];
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
rctx.src0_nrows = src0_nrows;
|
||||
|
||||
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
|
||||
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
||||
@@ -43,11 +43,21 @@
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
struct htp_set_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
struct fastdiv_values div_ne12;
|
||||
struct fastdiv_values div_ne11;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
};
|
||||
|
||||
static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
|
||||
struct htp_ops_context * octx = srctx->octx;
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
@@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
@@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
|
||||
struct htp_ops_context * octx = srctx->octx;
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
@@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
@@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
@@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
|
||||
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
|
||||
struct htp_set_rows_context srctx;
|
||||
srctx.octx = octx;
|
||||
srctx.div_ne12 = init_fastdiv_values(ne12);
|
||||
srctx.div_ne11 = init_fastdiv_values(ne11);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
|
||||
break;
|
||||
case HTP_TYPE_F16:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
|
||||
break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
@@ -48,7 +49,7 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct softmax_th_ctx {
|
||||
struct htp_softmax_context {
|
||||
bool use_f16;
|
||||
bool use_src1;
|
||||
uint32_t n_head;
|
||||
@@ -59,28 +60,48 @@ struct softmax_th_ctx {
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
struct fastdiv_values fastdiv_ne01;
|
||||
struct fastdiv_values fastdiv_ne02;
|
||||
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
|
||||
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
};
|
||||
|
||||
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
|
||||
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
|
||||
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
|
||||
memset(smctx, 0, sizeof(struct htp_softmax_context));
|
||||
|
||||
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
|
||||
softmax_ctx->n_head = src0->ne[2];
|
||||
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
|
||||
smctx->n_head = src0->ne[2];
|
||||
smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
|
||||
|
||||
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
|
||||
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
|
||||
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
|
||||
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
|
||||
|
||||
softmax_ctx->use_src1 = (src1->ne[0] != 0);
|
||||
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
smctx->use_src1 = (src1->ne[0] != 0);
|
||||
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
|
||||
softmax_ctx->octx = octx;
|
||||
smctx->octx = octx;
|
||||
|
||||
// Initialize fastdiv values
|
||||
const uint32_t ne01 = src0->ne[1];
|
||||
const uint32_t ne02 = src0->ne[2];
|
||||
|
||||
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
|
||||
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
|
||||
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
|
||||
|
||||
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
|
||||
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||
@@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
|
||||
max_vec = hvx_vec_repl4(v);
|
||||
max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
@@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
v_pad[i] = v3;
|
||||
}
|
||||
|
||||
v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
|
||||
sum_vec = hvx_vec_repl4(v);
|
||||
sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
|
||||
|
||||
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
|
||||
@@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||
return sum;
|
||||
}
|
||||
|
||||
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const uint32_t i11 = i01;
|
||||
const uint32_t i12 = i02 % ne12;
|
||||
const uint32_t i13 = i03 % ne13;
|
||||
|
||||
// ALiBi
|
||||
const uint32_t h = i02; // head
|
||||
|
||||
const float slope = (softmax_ctx->max_bias > 0.0f) ?
|
||||
h < softmax_ctx->n_head_log2 ?
|
||||
powf(softmax_ctx->m0, h + 1) :
|
||||
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (softmax_ctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
|
||||
if (mp_f32) {
|
||||
if (softmax_ctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
|
||||
struct htp_ops_context * octx = smctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
@@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
||||
htp_softmax_preamble3;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
uint32_t prev_i2 = (uint32_t)-1;
|
||||
float slope = 1.0f;
|
||||
|
||||
for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
|
||||
uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
|
||||
uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
|
||||
uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
|
||||
uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
|
||||
|
||||
// Map to original logic indices
|
||||
// i01 = i1
|
||||
// i02 = i2
|
||||
// i03 = i3
|
||||
|
||||
const uint32_t i11 = i1;
|
||||
// const uint32_t i12 = i2 % ne12;
|
||||
// const uint32_t i13 = i3 % ne13;
|
||||
|
||||
uint32_t i12, i13;
|
||||
if (ne12 == ne02) {
|
||||
i12 = i2;
|
||||
} else {
|
||||
i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
|
||||
}
|
||||
|
||||
if (ne13 == ne03) {
|
||||
i13 = i3;
|
||||
} else {
|
||||
i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (i2 != prev_i2) {
|
||||
const uint32_t h = i2; // head
|
||||
|
||||
slope = (smctx->max_bias > 0.0f) ?
|
||||
h < smctx->n_head_log2 ?
|
||||
powf(smctx->m0, h + 1) :
|
||||
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
prev_i2 = i2;
|
||||
}
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
if (mp_f32) {
|
||||
if (smctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
|
||||
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
|
||||
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
|
||||
}
|
||||
|
||||
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
@@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct softmax_th_ctx softmax_ctx;
|
||||
struct htp_softmax_context smctx;
|
||||
const char * op_type = "softmax-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_SOFTMAX:
|
||||
op_func = softmax_job_dispatcher_f32;
|
||||
op_type = "softmax-f32";
|
||||
|
||||
init_softmax_ctx(&softmax_ctx, octx);
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
// Use stride for calculating offset
|
||||
smctx.spad_stride = hex_round_up(src0_row_size, 128);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
@@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
@@ -42,53 +41,54 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
|
||||
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
sum_rows_preamble;
|
||||
struct sum_rows_context {
|
||||
const uint8_t * src_data;
|
||||
uint8_t * dst_data;
|
||||
uint32_t ne00;
|
||||
size_t src_stride;
|
||||
size_t dst_stride;
|
||||
uint32_t rows_per_thread;
|
||||
uint32_t total_rows;
|
||||
bool opt_path;
|
||||
};
|
||||
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t rows_per_thread = smctx->rows_per_thread;
|
||||
const uint32_t total_rows = smctx->total_rows;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
const uint32_t start_row = rows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return HTP_STATUS_OK;
|
||||
if (start_row >= end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
const size_t src_stride = smctx->src_stride;
|
||||
const size_t dst_stride = smctx->dst_stride;
|
||||
const uint32_t ne00 = smctx->ne00;
|
||||
const bool opt_path = smctx->opt_path;
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
|
||||
float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride));
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
// Calculate actual number of rows for this thread
|
||||
const uint32_t n_rows = end_row - start_row;
|
||||
|
||||
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * ne00);
|
||||
for (uint32_t ir = 0; ir < n_rows; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
|
||||
|
||||
if (ir + 1 < src0_nrows_per_thread) {
|
||||
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
|
||||
if (ir + 1 < n_rows) {
|
||||
hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
if (opt_path) {
|
||||
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
|
||||
} else {
|
||||
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
|
||||
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
@@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) {
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
|
||||
bool opt_path = false;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = true;
|
||||
}
|
||||
|
||||
struct sum_rows_context smctx = {
|
||||
.src_data = (const uint8_t *) src0->data,
|
||||
.dst_data = (uint8_t *) dst->data,
|
||||
.ne00 = ne00,
|
||||
.src_stride = nb01,
|
||||
.dst_stride = nb1,
|
||||
.rows_per_thread = rows_per_thread,
|
||||
.total_rows = src0_nrows,
|
||||
.opt_path = opt_path,
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,28 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
// Precomputed values
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
|
||||
size_t src0_spad_half_size;
|
||||
size_t dst_spad_half_size;
|
||||
|
||||
uint32_t block;
|
||||
uint32_t src0_nrows;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
@@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
sum_v = hvx_vec_repl4(reduced_sum);
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
@@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void scale_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void scale_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float scale = 0.f;
|
||||
float bias = 0.f;
|
||||
memcpy(&scale, &op_params[0], sizeof(float));
|
||||
memcpy(&bias, &op_params[1], sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void rms_norm_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float epsilon = 0.f;
|
||||
memcpy(&epsilon, op_params, sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
} else {
|
||||
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
|
||||
|
||||
const float mean = sum / row_elems;
|
||||
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||
|
||||
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
|
||||
}
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void sqr_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqrt_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void sqrt_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
int htp_op,
|
||||
int32_t * op_params,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_unary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
int htp_op = octx->op;
|
||||
int32_t * op_params = octx->op_params;
|
||||
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const size_t src0_row_size = uctx->src0_row_size;
|
||||
const size_t dst_row_size = uctx->dst_row_size;
|
||||
|
||||
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = uctx->src0_nrows;
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
@@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
const uint8_t * restrict data_src = uctx->data_src0;
|
||||
uint8_t * restrict data_dst = uctx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = uctx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
default:
|
||||
break;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
// Process block in VTCM
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
|
||||
dst_row_size, dst_row_size_aligned, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
|
||||
FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
|
||||
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
|
||||
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
|
||||
octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t unary_op_func;
|
||||
const char * op_type = NULL;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "rmsnorm-f32";
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqr-f32";
|
||||
op_type = "sqr-f32";
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqrt-f32";
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
// Double buffering requires 2x size per buffer
|
||||
|
||||
size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
||||
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (vtcm_row_per_thread == 0) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size_per_row * n_threads);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
|
||||
octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
|
||||
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
struct htp_unary_context uctx = {
|
||||
.octx = octx,
|
||||
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
|
||||
.src0_nrows = src0_nrows,
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
.data_dst = (uint8_t *)dst->data,
|
||||
|
||||
.src0_row_size = src0_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
|
||||
.src0_row_size_aligned = src0_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
|
||||
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
|
||||
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
|
||||
|
||||
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
|
||||
.nc = src0->ne[0],
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
||||
@@ -484,7 +484,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_scale_f32, kernel_scale_f32_4;
|
||||
cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
|
||||
cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
|
||||
cl_kernel kernel_mean_f32;
|
||||
cl_kernel kernel_mean_f32, kernel_mean_f32_4;
|
||||
cl_kernel kernel_silu, kernel_silu_4;
|
||||
cl_kernel kernel_gelu, kernel_gelu_4;
|
||||
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
|
||||
@@ -543,15 +543,15 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_solve_tri_f32;
|
||||
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
|
||||
cl_kernel kernel_argsort_f32_i32;
|
||||
cl_kernel kernel_sum_rows_f32;
|
||||
cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
|
||||
cl_kernel kernel_repeat_f32;
|
||||
cl_kernel kernel_pad;
|
||||
cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
|
||||
cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;
|
||||
cl_kernel kernel_expm1_f32_nd;
|
||||
cl_kernel kernel_expm1_f16_nd;
|
||||
cl_kernel kernel_softplus_f32_nd;
|
||||
cl_kernel kernel_softplus_f16_nd;
|
||||
cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc;
|
||||
cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc;
|
||||
cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc;
|
||||
cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc;
|
||||
cl_kernel kernel_upscale;
|
||||
cl_kernel kernel_upscale_bilinear;
|
||||
cl_kernel kernel_concat_f32;
|
||||
@@ -1837,6 +1837,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err));
|
||||
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
@@ -1874,6 +1875,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -1978,20 +1980,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
#else
|
||||
const std::string kernel_src = read_file("expm1.cl");
|
||||
#endif
|
||||
cl_program prog;
|
||||
if (!kernel_src.empty()) {
|
||||
prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
} else {
|
||||
GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
|
||||
prog = nullptr;
|
||||
backend_ctx->kernel_expm1_f32_nd = nullptr;
|
||||
backend_ctx->kernel_expm1_f16_nd = nullptr;
|
||||
}
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f32 = clCreateKernel(prog, "kernel_expm1_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f32_4 = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f16 = clCreateKernel(prog, "kernel_expm1_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f16_4 = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// softplus
|
||||
@@ -2003,20 +2001,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
#else
|
||||
const std::string kernel_src = read_file("softplus.cl");
|
||||
#endif
|
||||
cl_program prog;
|
||||
if (!kernel_src.empty()) {
|
||||
prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
} else {
|
||||
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
|
||||
prog = nullptr;
|
||||
backend_ctx->kernel_softplus_f32_nd = nullptr;
|
||||
backend_ctx->kernel_softplus_f16_nd = nullptr;
|
||||
}
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// upscale
|
||||
@@ -3463,11 +3457,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_TANH:
|
||||
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
||||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
|
||||
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
||||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
|
||||
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3587,7 +3579,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
}
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const ggml_tensor * q = op->src[0];
|
||||
@@ -6400,7 +6392,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
@@ -6423,7 +6414,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_mean_f32;
|
||||
cl_kernel kernel;
|
||||
|
||||
const bool is_c4 = ne00 % 4 == 0;
|
||||
if (is_c4) {
|
||||
kernel = backend_ctx->kernel_mean_f32_4;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_mean_f32;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
@@ -6440,7 +6438,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
|
||||
size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)64, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
@@ -7388,18 +7386,8 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
|
||||
|
||||
cl_kernel kernel;
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_expm1_f32_nd;
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
kernel = backend_ctx->kernel_expm1_f16_nd;
|
||||
} else {
|
||||
GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
|
||||
}
|
||||
GGML_ASSERT(kernel != nullptr);
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
@@ -7411,70 +7399,74 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
const cl_ulong nb03 = src0->nb[3];
|
||||
|
||||
const int ne10 = dst->ne[0];
|
||||
const int ne11 = dst->ne[1];
|
||||
const int ne12 = dst->ne[2];
|
||||
const int ne13 = dst->ne[3];
|
||||
const cl_ulong nb0 = dst->nb[0];
|
||||
const cl_ulong nb1 = dst->nb[1];
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
const cl_ulong nb10 = dst->nb[0];
|
||||
const cl_ulong nb11 = dst->nb[1];
|
||||
const cl_ulong nb12 = dst->nb[2];
|
||||
const cl_ulong nb13 = dst->nb[3];
|
||||
cl_kernel kernel;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
|
||||
|
||||
size_t global_work_size[3];
|
||||
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
|
||||
return;
|
||||
}
|
||||
global_work_size[0] = (size_t)ne10;
|
||||
global_work_size[1] = (size_t)ne11;
|
||||
global_work_size[2] = (size_t)ne12;
|
||||
|
||||
size_t lws0 = 16, lws1 = 4, lws2 = 1;
|
||||
if (ne10 < 16) lws0 = ne10;
|
||||
if (ne11 < 4) lws1 = ne11;
|
||||
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
|
||||
|
||||
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
|
||||
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
|
||||
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
|
||||
|
||||
|
||||
size_t local_work_size[] = {lws0, lws1, lws2};
|
||||
|
||||
size_t* local_work_size_ptr = local_work_size;
|
||||
if (!backend_ctx->non_uniform_workgroups) {
|
||||
if (global_work_size[0] % local_work_size[0] != 0 ||
|
||||
global_work_size[1] % local_work_size[1] != 0 ||
|
||||
global_work_size[2] % local_work_size[2] != 0) {
|
||||
local_work_size_ptr = NULL;
|
||||
if (ggml_is_contiguous(src0)) {
|
||||
// Handle contiguous input
|
||||
int n = ggml_nelements(dst);
|
||||
if (n % 4 == 0) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_expm1_f32_4;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_expm1_f16_4;
|
||||
}
|
||||
n /= 4;
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_expm1_f32;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_expm1_f16;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
|
||||
size_t global_work_size[] = {(size_t)n, 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
size_t * local_work_size_ptr = local_work_size;
|
||||
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
|
||||
local_work_size_ptr = nullptr;
|
||||
}
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||
} else {
|
||||
// Handle non-contiguous input
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_expm1_f32_nc;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_expm1_f16_nc;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
|
||||
|
||||
int nth = 64;
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@@ -7490,18 +7482,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0_abs = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
|
||||
|
||||
cl_kernel kernel;
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_softplus_f32_nd;
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
kernel = backend_ctx->kernel_softplus_f16_nd;
|
||||
} else {
|
||||
GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
|
||||
}
|
||||
GGML_ASSERT(kernel != nullptr);
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
@@ -7513,70 +7495,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
const cl_ulong nb03 = src0->nb[3];
|
||||
|
||||
const int ne10 = dst->ne[0];
|
||||
const int ne11 = dst->ne[1];
|
||||
const int ne12 = dst->ne[2];
|
||||
const int ne13 = dst->ne[3];
|
||||
const cl_ulong nb0 = dst->nb[0];
|
||||
const cl_ulong nb1 = dst->nb[1];
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
const cl_ulong nb10 = dst->nb[0];
|
||||
const cl_ulong nb11 = dst->nb[1];
|
||||
const cl_ulong nb12 = dst->nb[2];
|
||||
const cl_ulong nb13 = dst->nb[3];
|
||||
cl_kernel kernel;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
|
||||
|
||||
size_t global_work_size[3];
|
||||
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
|
||||
return;
|
||||
}
|
||||
global_work_size[0] = (size_t)ne10;
|
||||
global_work_size[1] = (size_t)ne11;
|
||||
global_work_size[2] = (size_t)ne12;
|
||||
|
||||
size_t lws0 = 16, lws1 = 4, lws2 = 1;
|
||||
if (ne10 < 16) lws0 = ne10;
|
||||
if (ne11 < 4) lws1 = ne11;
|
||||
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
|
||||
|
||||
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
|
||||
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
|
||||
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
|
||||
|
||||
|
||||
size_t local_work_size[] = {lws0, lws1, lws2};
|
||||
|
||||
size_t* local_work_size_ptr = local_work_size;
|
||||
if (!backend_ctx->non_uniform_workgroups) {
|
||||
if (global_work_size[0] % local_work_size[0] != 0 ||
|
||||
global_work_size[1] % local_work_size[1] != 0 ||
|
||||
global_work_size[2] % local_work_size[2] != 0) {
|
||||
local_work_size_ptr = NULL;
|
||||
if (ggml_is_contiguous(src0)) {
|
||||
// Handle contiguous input
|
||||
int n = ggml_nelements(dst);
|
||||
if (n % 4 == 0) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_softplus_f32_4;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_softplus_f16_4;
|
||||
}
|
||||
n /= 4;
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_softplus_f32;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_softplus_f16;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
|
||||
size_t global_work_size[] = {(size_t)n, 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
size_t * local_work_size_ptr = local_work_size;
|
||||
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
|
||||
local_work_size_ptr = nullptr;
|
||||
}
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||
} else {
|
||||
// Handle non-contiguous input
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_softplus_f32_nc;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_softplus_f16_nc;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
|
||||
|
||||
int nth = 64;
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
|
||||
@@ -11088,7 +11074,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
GGML_UNUSED(src1);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
@@ -11111,7 +11096,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
|
||||
cl_kernel kernel;
|
||||
|
||||
const bool is_c4 = ne00 % 4 == 0;
|
||||
if (is_c4) {
|
||||
kernel = backend_ctx->kernel_sum_rows_f32_4;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_sum_rows_f32;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
@@ -11128,7 +11120,7 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
|
||||
size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)64, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
@@ -3,80 +3,111 @@
|
||||
//------------------------------------------------------------------------------
|
||||
// expm1
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_expm1_f32_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
|
||||
kernel void kernel_expm1_f32(
|
||||
global const float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_4(
|
||||
global const float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f16(
|
||||
global const half * src0,
|
||||
ulong offset0,
|
||||
global half * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half*)((global char*)src0 + offset0);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f16_4(
|
||||
global const half4 * src0,
|
||||
ulong offset0,
|
||||
global half4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half4*)((global char*)src0 + offset0);
|
||||
dst = (global half4*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = exp(*src_val_ptr) - 1;
|
||||
}
|
||||
*y = exp(*x) - 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f16_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
kernel void kernel_expm1_f16_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = exp(*src_val_ptr) - 1;
|
||||
}
|
||||
*y = exp(*x) - 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
// Most devices have max workgroup size of 1024, so this is enough for subgroup
|
||||
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
|
||||
#define MAX_SUBGROUPS 64
|
||||
kernel void kernel_mean_f32(
|
||||
global float * src0,
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i3 = get_global_id(2);
|
||||
int i2 = get_global_id(1);
|
||||
int i1 = get_global_id(0);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
|
||||
for (int i0 = 0; i0 < ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum / ne00;
|
||||
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00; i0 += lsize) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf / ne00;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mean_f32_4(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float4 sum_vec = (float4)0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
|
||||
sum_vec += src_row[i0];
|
||||
}
|
||||
|
||||
float sumf = dot(sum_vec, (float4)(1.0f));
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf / ne00;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,86 +3,114 @@
|
||||
//------------------------------------------------------------------------------
|
||||
// softplus
|
||||
//------------------------------------------------------------------------------
|
||||
inline float softplus_f32(float x){
|
||||
float ax = fabs(x);
|
||||
float m = fmax(x, 0.0f);
|
||||
return log1p(exp(-ax)) + m;
|
||||
|
||||
kernel void kernel_softplus_f32(
|
||||
global const float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
kernel void kernel_softplus_f32_4(
|
||||
global const float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f16(
|
||||
global const half * src0,
|
||||
ulong offset0,
|
||||
global half * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half*)((global char*)src0 + offset0);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
const float x = convert_float(src0[get_global_id(0)]);
|
||||
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f16_4(
|
||||
global const half4 * src0,
|
||||
ulong offset0,
|
||||
global half4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half4*)((global char*)src0 + offset0);
|
||||
dst = (global half4*)((global char*)dst + offsetd);
|
||||
|
||||
const float4 x = convert_float4(src0[get_global_id(0)]);
|
||||
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = softplus_f32(*src_val_ptr);
|
||||
}
|
||||
*y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f16_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
kernel void kernel_softplus_f16_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
|
||||
}
|
||||
const float x = convert_float(*hx);
|
||||
*hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
// Most devices have max workgroup size of 1024, so this is enough for subgroup
|
||||
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
|
||||
#define MAX_SUBGROUPS 64
|
||||
kernel void kernel_sum_rows_f32(
|
||||
global float * src0,
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i3 = get_global_id(2);
|
||||
int i2 = get_global_id(1);
|
||||
int i1 = get_global_id(0);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
|
||||
for (int i0 = 0; i0 < ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00; i0 += lsize) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows_f32_4(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float4 sum_vec = (float4)0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
|
||||
sum_vec += src_row[i0];
|
||||
}
|
||||
|
||||
float sumf = dot(sum_vec, (float4)(1.0f));
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
const int32_t* src2_d = (const int32_t*)src2->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
int threads = std::min((int)ne00, 768); // cols
|
||||
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
|
||||
int threads = std::min((unsigned int)ne00, max_work_group_size); // cols
|
||||
|
||||
ctx.stream()->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
|
||||
|
||||
@@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
int s00, int s01, int s02, int s03,
|
||||
int s10, int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
for (int i0 = i0s; i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
int s00, int s01, int s02, int s03,
|
||||
int s10, int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
@@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,8 @@ struct bin_bcast_sycl {
|
||||
const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
|
||||
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
|
||||
const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,
|
||||
queue_ptr stream) {
|
||||
int nr0 = ne10 / ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
@@ -123,7 +124,7 @@ struct bin_bcast_sycl {
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
|
||||
if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
@@ -164,7 +165,7 @@ struct bin_bcast_sycl {
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
// size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
@@ -196,9 +197,6 @@ struct bin_bcast_sycl {
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
||||
@@ -232,8 +230,8 @@ struct bin_bcast_sycl {
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast_unravel<bin_op>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
|
||||
s03, s11, s12, s13, item_ct1);
|
||||
ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,
|
||||
s03, s10, s11, s12, s13, item_ct1);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
@@ -251,7 +249,7 @@ struct bin_bcast_sycl {
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
||||
ne2, ne3, ne10, ne11, ne12, ne13,
|
||||
s1, s2, s3, s01, s02, s03, s11, s12, s13,
|
||||
s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
|
||||
ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
|
||||
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
|
||||
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
|
||||
ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
|
||||
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
|
||||
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||
op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
|
||||
ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
|
||||
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
|
||||
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
|
||||
op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
|
||||
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
|
||||
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
|
||||
op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
|
||||
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
|
||||
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
|
||||
main_stream);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
||||
@@ -7,9 +7,21 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) {
|
||||
if (cgraph_size == 0) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// place-holder: validate that the size of shmem_res_id is <= cgraph_size
|
||||
// need to add another method in the Virgl->APIR callback interface
|
||||
GGML_UNUSED(shmem_res_id);
|
||||
|
||||
return 0; // Valid
|
||||
}
|
||||
|
||||
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(enc);
|
||||
|
||||
static bool async_backend_initialized = false;
|
||||
static bool async_backend;
|
||||
@@ -34,10 +46,26 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
|
||||
size_t cgraph_size;
|
||||
apir_decode_size_t(dec, &cgraph_size);
|
||||
|
||||
if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) {
|
||||
apir_decoder_set_fatal(dec);
|
||||
return 1;
|
||||
}
|
||||
|
||||
apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
|
||||
|
||||
ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
|
||||
|
||||
if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__,
|
||||
cgraph->n_nodes, cgraph->n_leafs);
|
||||
return 1;
|
||||
}
|
||||
|
||||
ggml_status status;
|
||||
#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
|
||||
for (int idx = 0; idx < cgraph->n_nodes; idx++) {
|
||||
@@ -45,7 +73,8 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
|
||||
if (dev->iface.supports_op(dev, op)) {
|
||||
continue;
|
||||
}
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op));
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx,
|
||||
ggml_op_desc(op));
|
||||
|
||||
status = GGML_STATUS_ABORTED;
|
||||
apir_encode_ggml_status(enc, &status);
|
||||
@@ -53,9 +82,17 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Check if backend is properly initialized
|
||||
if (!bck) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
status = bck->iface.graph_compute(bck, cgraph);
|
||||
|
||||
if (async_backend) {
|
||||
if (async_backend && bck->iface.synchronize) {
|
||||
bck->iface.synchronize(bck);
|
||||
}
|
||||
|
||||
|
||||
@@ -85,7 +85,19 @@ uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * d
|
||||
|
||||
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
|
||||
|
||||
size_t value = buft->iface.get_alloc_size(buft, op);
|
||||
// Check for decode error
|
||||
if (op == nullptr) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__);
|
||||
apir_decoder_set_fatal(dec);
|
||||
return 1;
|
||||
}
|
||||
|
||||
size_t value;
|
||||
if (buft->iface.get_alloc_size) {
|
||||
value = buft->iface.get_alloc_size(buft, op);
|
||||
} else {
|
||||
value = ggml_nbytes(op); // Default fallback
|
||||
}
|
||||
|
||||
apir_encode_size_t(enc, &value);
|
||||
|
||||
|
||||
@@ -6,11 +6,26 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) {
|
||||
// Only check for critical integer overflow - no arbitrary size limits
|
||||
if (offset > SIZE_MAX - size) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0; // Valid
|
||||
}
|
||||
|
||||
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
|
||||
GGML_UNUSED(ctx);
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!buffer || apir_decoder_get_fatal(dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
|
||||
apir_encode_uintptr_t(enc, &base);
|
||||
|
||||
@@ -24,6 +39,11 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!buffer || apir_decoder_get_fatal(dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
ggml_tensor * tensor;
|
||||
// safe to remove the const qualifier here
|
||||
tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
|
||||
@@ -37,6 +57,10 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl
|
||||
size_t size;
|
||||
apir_decode_size_t(dec, &size);
|
||||
|
||||
if (validate_buffer_operation(offset, size, __func__) != 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
|
||||
|
||||
if (!shmem_data) {
|
||||
@@ -56,6 +80,11 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!buffer || apir_decoder_get_fatal(dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const ggml_tensor * tensor;
|
||||
// safe to remove the const qualifier here
|
||||
tensor = apir_decode_ggml_tensor(dec);
|
||||
@@ -69,6 +98,10 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl
|
||||
size_t size;
|
||||
apir_decode_size_t(dec, &size);
|
||||
|
||||
if (validate_buffer_operation(offset, size, __func__) != 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
|
||||
if (!shmem_data) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
|
||||
@@ -86,6 +119,11 @@ uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!buffer || apir_decoder_get_fatal(dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const ggml_tensor * src;
|
||||
// safe to remove the const qualifier here
|
||||
src = apir_decode_ggml_tensor(dec);
|
||||
@@ -105,6 +143,11 @@ uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!buffer || apir_decoder_get_fatal(dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
uint8_t value;
|
||||
apir_decode_uint8_t(dec, &value);
|
||||
|
||||
@@ -120,6 +163,11 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg
|
||||
ggml_backend_buffer_t buffer;
|
||||
buffer = apir_decode_ggml_buffer(dec);
|
||||
|
||||
if (!buffer || apir_decoder_get_fatal(dec)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!apir_untrack_backend_buffer(buffer)) {
|
||||
GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer);
|
||||
return 1;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
|
||||
#include "backend-virgl-apir.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
@@ -28,19 +28,24 @@ uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {
|
||||
return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;
|
||||
}
|
||||
|
||||
if (!reg->iface.get_device_count(reg)) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__);
|
||||
size_t device_count = reg->iface.get_device_count(reg);
|
||||
if (!device_count) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
|
||||
}
|
||||
|
||||
dev = reg->iface.get_device(reg, 0);
|
||||
|
||||
if (!dev) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
|
||||
}
|
||||
|
||||
bck = dev->iface.init_backend(dev, NULL);
|
||||
if (!bck) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__);
|
||||
return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED;
|
||||
}
|
||||
|
||||
return APIR_BACKEND_INITIALIZE_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -32,64 +32,6 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg
|
||||
/* backend */
|
||||
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
|
||||
|
||||
static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) {
|
||||
switch (type) {
|
||||
/* device */
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
|
||||
return "backend_device_get_device_count";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
|
||||
return "backend_device_get_count";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
|
||||
return "backend_device_get_name";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
|
||||
return "backend_device_get_description";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
|
||||
return "backend_device_get_type";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
|
||||
return "backend_device_get_memory";
|
||||
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
|
||||
return "backend_device_supports_op";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
|
||||
return "backend_device_get_buffer_type";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
|
||||
return "backend_device_get_props";
|
||||
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
|
||||
return "backend_device_buffer_from_ptr";
|
||||
/* buffer-type */
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
|
||||
return "backend_buffer_type_get_name";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
|
||||
return "backend_buffer_type_get_alignment";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
|
||||
return "backend_buffer_type_get_max_size";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
|
||||
return "backend_buffer_type_is_host (DEPRECATED)";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
|
||||
return "backend_buffer_type_alloc_buffer";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
|
||||
return "backend_buffer_type_get_alloc_size";
|
||||
/* buffer */
|
||||
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
|
||||
return "backend_buffer_get_base";
|
||||
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
|
||||
return "backend_buffer_set_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
|
||||
return "backend_buffer_get_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
|
||||
return "backend_buffer_cpy_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
|
||||
return "backend_buffer_clear";
|
||||
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
|
||||
return "backend_buffer_free_buffer";
|
||||
/* backend */
|
||||
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
|
||||
return "backend_backend_graph_compute";
|
||||
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
|
||||
@@ -10,6 +11,7 @@
|
||||
#include "shared/apir_backend.h"
|
||||
#include "shared/apir_cs.h"
|
||||
#include "shared/apir_cs_ggml.h"
|
||||
// clang-format on
|
||||
|
||||
#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: "
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ struct virgl_apir_callbacks {
|
||||
};
|
||||
|
||||
extern "C" {
|
||||
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs);
|
||||
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs);
|
||||
void apir_backend_deinit(uint32_t virgl_ctx_id);
|
||||
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
|
||||
virgl_apir_callbacks * virgl_cbs,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "backend-dispatched.h"
|
||||
#include "backend-virgl-apir.h"
|
||||
|
||||
#include "shared/api_remoting.h"
|
||||
#include "shared/apir_backend.h"
|
||||
#include "shared/apir_cs.h"
|
||||
@@ -17,10 +16,10 @@
|
||||
#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
|
||||
|
||||
static void * backend_library_handle = NULL;
|
||||
static FILE * apir_logfile = NULL;
|
||||
static FILE * apir_logfile = NULL;
|
||||
|
||||
static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
|
||||
FILE * logfile = (FILE *)user_data;
|
||||
FILE * logfile = (FILE *) user_data;
|
||||
fprintf(logfile, "[%d] %s", level, text);
|
||||
fflush(logfile);
|
||||
}
|
||||
@@ -48,9 +47,9 @@ void apir_backend_deinit(uint32_t virgl_ctx_id) {
|
||||
}
|
||||
|
||||
#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
|
||||
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
|
||||
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
|
||||
|
||||
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) {
|
||||
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) {
|
||||
const char * dlsym_error;
|
||||
|
||||
const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
|
||||
@@ -63,15 +62,13 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
|
||||
}
|
||||
}
|
||||
|
||||
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
|
||||
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
|
||||
const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
|
||||
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
|
||||
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
|
||||
|
||||
if (!library_name) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
|
||||
"%s: cannot open the GGML library: env var '%s' not defined\n",
|
||||
__func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
|
||||
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__,
|
||||
APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
|
||||
|
||||
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
|
||||
}
|
||||
@@ -79,16 +76,14 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
|
||||
backend_library_handle = dlopen(library_name, RTLD_LAZY);
|
||||
|
||||
if (!backend_library_handle) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
|
||||
"%s: cannot open the GGML library: %s\n", __func__, dlerror());
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror());
|
||||
|
||||
return APIR_LOAD_LIBRARY_CANNOT_OPEN;
|
||||
}
|
||||
|
||||
if (!library_reg) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
|
||||
"%s: cannot register the GGML library: env var '%s' not defined\n",
|
||||
__func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__,
|
||||
APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
|
||||
|
||||
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
|
||||
}
|
||||
@@ -96,11 +91,9 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
|
||||
void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
|
||||
dlsym_error = dlerror();
|
||||
if (dlsym_error) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
|
||||
"%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
|
||||
__func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
|
||||
|
||||
|
||||
return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
|
||||
}
|
||||
|
||||
@@ -132,13 +125,12 @@ uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
|
||||
|
||||
virgl_apir_context ctx = {
|
||||
.ctx_id = virgl_ctx_id,
|
||||
.iface = virgl_cbs,
|
||||
.iface = virgl_cbs,
|
||||
};
|
||||
|
||||
if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
|
||||
"%s: Received an invalid dispatch index (%d >= %d)\n",
|
||||
__func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type,
|
||||
APIR_BACKEND_DISPATCH_TABLE_COUNT);
|
||||
return APIR_BACKEND_FORWARD_INDEX_INVALID;
|
||||
}
|
||||
|
||||
|
||||
@@ -16,28 +16,32 @@ enum ApirCommandType {
|
||||
APIR_COMMAND_TYPE_LOADLIBRARY = 1,
|
||||
APIR_COMMAND_TYPE_FORWARD = 2,
|
||||
|
||||
APIR_COMMAND_TYPE_LENGTH = 3,
|
||||
APIR_COMMAND_TYPE_LENGTH = 3,
|
||||
};
|
||||
|
||||
typedef uint64_t ApirCommandFlags;
|
||||
|
||||
enum ApirLoadLibraryReturnCode {
|
||||
APIR_LOAD_LIBRARY_SUCCESS = 0,
|
||||
// these error codes are returned by the Virglrenderer APIR component
|
||||
APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,
|
||||
APIR_LOAD_LIBRARY_ALREADY_LOADED = 2,
|
||||
APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3,
|
||||
APIR_LOAD_LIBRARY_CANNOT_OPEN = 4,
|
||||
APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5,
|
||||
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code
|
||||
// any value greater than this is an APIR *backend library* initialization return code
|
||||
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6,
|
||||
};
|
||||
|
||||
enum ApirForwardReturnCode {
|
||||
APIR_FORWARD_SUCCESS = 0,
|
||||
APIR_FORWARD_NO_DISPATCH_FCT = 1,
|
||||
APIR_FORWARD_TIMEOUT = 2,
|
||||
|
||||
APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code
|
||||
} ;
|
||||
APIR_FORWARD_SUCCESS = 0,
|
||||
// these error codes are returned by the Virglrenderer APIR component
|
||||
APIR_FORWARD_NO_DISPATCH_FCT = 1,
|
||||
APIR_FORWARD_TIMEOUT = 2,
|
||||
APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3,
|
||||
// any value greater than this index an APIR *backend library* forward return code
|
||||
APIR_FORWARD_BASE_INDEX = 4,
|
||||
};
|
||||
|
||||
__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {
|
||||
switch (type) {
|
||||
@@ -82,6 +86,7 @@ __attribute__((unused)) static const char * apir_forward_error(ApirForwardReturn
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS);
|
||||
APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);
|
||||
|
||||
return "Unknown APIR_COMMAND_TYPE_FORWARD error";
|
||||
|
||||
@@ -34,3 +34,61 @@ typedef enum ApirBackendCommandType {
|
||||
// last command_type index + 1
|
||||
APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,
|
||||
} ApirBackendCommandType;
|
||||
|
||||
static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {
|
||||
switch (type) {
|
||||
/* device */
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
|
||||
return "device_get_device_count";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
|
||||
return "device_get_count";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
|
||||
return "device_get_name";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
|
||||
return "device_get_description";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
|
||||
return "device_get_type";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
|
||||
return "device_get_memory";
|
||||
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
|
||||
return "device_supports_op";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
|
||||
return "device_get_buffer_type";
|
||||
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
|
||||
return "device_get_props";
|
||||
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
|
||||
return "device_buffer_from_ptr";
|
||||
/* buffer-type */
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
|
||||
return "buffer_type_get_name";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
|
||||
return "buffer_type_get_alignment";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
|
||||
return "buffer_type_get_max_size";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
|
||||
return "buffer_type_is_host";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
|
||||
return "buffer_type_alloc_buffer";
|
||||
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
|
||||
return "buffer_type_get_alloc_size";
|
||||
/* buffer */
|
||||
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
|
||||
return "buffer_get_base";
|
||||
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
|
||||
return "buffer_set_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
|
||||
return "buffer_get_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
|
||||
return "buffer_cpy_tensor";
|
||||
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
|
||||
return "buffer_clear";
|
||||
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
|
||||
return "buffer_free_buffer";
|
||||
/* backend */
|
||||
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
|
||||
return "backend_graph_compute";
|
||||
|
||||
default:
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6
|
||||
#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7
|
||||
#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8
|
||||
|
||||
#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED 9
|
||||
|
||||
// new entries here need to be added to the apir_backend_initialize_error function below
|
||||
|
||||
@@ -39,6 +39,10 @@ static const char * apir_backend_initialize_error(int code) {
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE);
|
||||
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED);
|
||||
|
||||
return "Unknown APIR_BACKEND_INITIALIZE error:/";
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ struct apir_encoder {
|
||||
const char * start;
|
||||
const char * end;
|
||||
bool fatal;
|
||||
|
||||
};
|
||||
|
||||
struct apir_decoder {
|
||||
@@ -28,8 +27,8 @@ struct apir_decoder {
|
||||
|
||||
static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
|
||||
apir_decoder dec = {
|
||||
.cur = ptr,
|
||||
.end = ptr + size,
|
||||
.cur = ptr,
|
||||
.end = ptr + size,
|
||||
.fatal = false,
|
||||
};
|
||||
|
||||
@@ -79,10 +78,7 @@ static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
|
||||
* encode peek
|
||||
*/
|
||||
|
||||
static inline bool apir_decoder_peek_internal(apir_decoder * dec,
|
||||
size_t size,
|
||||
void * val,
|
||||
size_t val_size) {
|
||||
static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) {
|
||||
assert(val_size <= size);
|
||||
|
||||
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
|
||||
@@ -332,8 +328,7 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t
|
||||
static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
|
||||
size_t alloc_size;
|
||||
if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
|
||||
GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n",
|
||||
__func__, size, count);
|
||||
GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -352,20 +347,19 @@ static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
|
||||
|
||||
/* apir_buffer_type_host_handle_t */
|
||||
|
||||
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
|
||||
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
|
||||
const apir_buffer_type_host_handle_t * val) {
|
||||
apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
|
||||
}
|
||||
|
||||
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
|
||||
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
|
||||
apir_buffer_type_host_handle_t * val) {
|
||||
apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
|
||||
}
|
||||
|
||||
/* apir_buffer_host_handle_t */
|
||||
|
||||
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc,
|
||||
const apir_buffer_host_handle_t * val) {
|
||||
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) {
|
||||
apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "apir_cs.h"
|
||||
#include "apir_cs_rpc.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);
|
||||
|
||||
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc,
|
||||
const apir_buffer_host_handle_t * handle);
|
||||
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle);
|
||||
|
||||
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);
|
||||
|
||||
@@ -22,8 +21,7 @@ static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder
|
||||
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
|
||||
}
|
||||
|
||||
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec,
|
||||
uint32_t n_tensors) {
|
||||
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) {
|
||||
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;
|
||||
|
||||
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
|
||||
@@ -45,9 +43,9 @@ static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {
|
||||
}
|
||||
|
||||
ggml_init_params params{
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
/*.mem_size =*/ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/NULL,
|
||||
/*.no_alloc =*/true,
|
||||
};
|
||||
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
@@ -105,6 +103,19 @@ static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec)
|
||||
|
||||
apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);
|
||||
|
||||
// SECURITY: Validate buffer handle against tracked buffers to prevent
|
||||
// guest VM from providing arbitrary host memory addresses
|
||||
if (buffer) {
|
||||
extern std::unordered_set<ggml_backend_buffer_t> backend_buffers;
|
||||
if (backend_buffers.find(buffer) == backend_buffers.end()) {
|
||||
GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__,
|
||||
(void *) buffer);
|
||||
// Set fatal flag to prevent further processing with invalid handle
|
||||
apir_decoder_set_fatal(dec);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
@@ -5,6 +8,7 @@
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
// clang-format on
|
||||
|
||||
// ggml_tensor is serialized into apir_rpc_tensor
|
||||
struct apir_rpc_tensor {
|
||||
|
||||
@@ -34,6 +34,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml
|
||||
static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
virtgpu * gpu = BUFT_TO_GPU(buft);
|
||||
|
||||
// Return the prefixed name that was built once during initialization
|
||||
return gpu->cached_buffer_type.name;
|
||||
}
|
||||
|
||||
@@ -53,9 +54,8 @@ static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buff
|
||||
const ggml_tensor * tensor) {
|
||||
virtgpu * gpu = BUFT_TO_GPU(buft);
|
||||
|
||||
if (tensor->buffer == NULL
|
||||
|| !tensor->buffer->context
|
||||
|| !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
|
||||
if (tensor->buffer == NULL || !tensor->buffer->context ||
|
||||
!buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
|
||||
return ggml_nbytes(tensor);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
|
||||
virtgpu * gpu = DEV_TO_GPU(dev);
|
||||
|
||||
// Return the prefixed name that was built once during initialization
|
||||
return gpu->cached_device_info.name;
|
||||
}
|
||||
|
||||
@@ -22,7 +23,7 @@ static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_bac
|
||||
static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
virtgpu * gpu = DEV_TO_GPU(dev);
|
||||
|
||||
*free = gpu->cached_device_info.memory_free;
|
||||
*free = gpu->cached_device_info.memory_free;
|
||||
*total = gpu->cached_device_info.memory_total;
|
||||
}
|
||||
|
||||
@@ -72,7 +73,7 @@ static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_
|
||||
ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
virtgpu * gpu = DEV_TO_GPU(dev);
|
||||
|
||||
static std::atomic<bool> initialized = false;
|
||||
static std::atomic<bool> initialized = false;
|
||||
static ggml_backend_buffer_type buft;
|
||||
|
||||
if (!initialized) {
|
||||
@@ -95,7 +96,7 @@ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_bac
|
||||
static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {
|
||||
virtgpu * gpu = DEV_TO_GPU(dev);
|
||||
|
||||
static std::atomic<bool> initialized = false;
|
||||
static std::atomic<bool> initialized = false;
|
||||
static ggml_backend_buffer_type buft;
|
||||
|
||||
if (!initialized) {
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
void ggml_virtgpu_cleanup(virtgpu * gpu);
|
||||
|
||||
static virtgpu * apir_initialize() {
|
||||
static virtgpu * gpu = NULL;
|
||||
static std::atomic<bool> initialized = false;
|
||||
static virtgpu * gpu = NULL;
|
||||
static std::atomic<bool> initialized = false;
|
||||
|
||||
if (initialized) {
|
||||
// fast track
|
||||
@@ -31,29 +31,53 @@ static virtgpu * apir_initialize() {
|
||||
}
|
||||
|
||||
// Pre-fetch and cache all device information, it will not change
|
||||
gpu->cached_device_info.description = apir_device_get_description(gpu);
|
||||
gpu->cached_device_info.description = apir_device_get_description(gpu);
|
||||
if (!gpu->cached_device_info.description) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__);
|
||||
}
|
||||
gpu->cached_device_info.name = apir_device_get_name(gpu);
|
||||
if (!gpu->cached_device_info.name) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__);
|
||||
}
|
||||
gpu->cached_device_info.device_count = apir_device_get_count(gpu);
|
||||
gpu->cached_device_info.type = apir_device_get_type(gpu);
|
||||
|
||||
apir_device_get_memory(gpu,
|
||||
&gpu->cached_device_info.memory_free,
|
||||
&gpu->cached_device_info.memory_total);
|
||||
{
|
||||
// Get the remote name and create prefixed version
|
||||
char * rmt_device_name = apir_device_get_name(gpu);
|
||||
if (!rmt_device_name) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__);
|
||||
}
|
||||
|
||||
size_t device_name_len = strlen(rmt_device_name) + 11; // "[virtgpu] " + null terminator
|
||||
gpu->cached_device_info.name = (char *) malloc(device_name_len);
|
||||
if (!gpu->cached_device_info.name) {
|
||||
free(rmt_device_name);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__);
|
||||
}
|
||||
snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name);
|
||||
free(rmt_device_name);
|
||||
}
|
||||
|
||||
apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total);
|
||||
|
||||
apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);
|
||||
gpu->cached_buffer_type.host_handle = buft_host_handle;
|
||||
gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle);
|
||||
if (!gpu->cached_buffer_type.name) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__);
|
||||
{
|
||||
// Get the remote name and create prefixed version
|
||||
char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle);
|
||||
if (!rmt_name) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__);
|
||||
}
|
||||
|
||||
size_t prefixed_len = strlen(rmt_name) + 11; // "[virtgpu] " + null terminator
|
||||
gpu->cached_buffer_type.name = (char *) malloc(prefixed_len);
|
||||
if (!gpu->cached_buffer_type.name) {
|
||||
free(rmt_name);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__);
|
||||
}
|
||||
snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name);
|
||||
free(rmt_name);
|
||||
}
|
||||
gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
|
||||
gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
|
||||
|
||||
gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
|
||||
gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
@@ -98,7 +122,7 @@ static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
|
||||
static std::atomic<bool> initialized = false;
|
||||
|
||||
if (initialized) {
|
||||
return; // fast track
|
||||
return; // fast track
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "ggml-remoting.h"
|
||||
#include "../../include/ggml-virtgpu.h"
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) {
|
||||
UNUSED(backend);
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <string>
|
||||
|
||||
#define GGML_VIRTGPU_NAME "ggml-virtgpu"
|
||||
#define GGML_VIRTGPU "ggml-virtgpu: "
|
||||
#define GGML_VIRTGPU "ggml-virtgpu: "
|
||||
|
||||
// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include <stdint.h>
|
||||
|
||||
struct virgl_renderer_capset_apir {
|
||||
uint32_t apir_version;
|
||||
uint32_t supports_blob_resources;
|
||||
uint32_t reserved[4]; // For future expansion
|
||||
uint32_t apir_version;
|
||||
uint32_t supports_blob_resources;
|
||||
uint32_t reserved[4]; // For future expansion
|
||||
};
|
||||
|
||||
@@ -145,8 +145,31 @@ class RemotingCodebaseGenerator:
|
||||
enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},")
|
||||
enum_lines.append("} ApirBackendCommandType;")
|
||||
|
||||
# Generate function name mapping
|
||||
func_lines = []
|
||||
func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {")
|
||||
func_lines.append(" switch (type) {")
|
||||
|
||||
current_group = None
|
||||
for func in functions:
|
||||
# Add comment for new group
|
||||
if func['group_name'] != current_group:
|
||||
func_lines.append(f" /* {func['group_description']} */")
|
||||
current_group = func['group_name']
|
||||
|
||||
# Generate clean function name without backend_ prefix
|
||||
clean_name = f"{func['group_name']}_{func['function_name']}"
|
||||
func_lines.append(f" case {func['enum_name']}:")
|
||||
func_lines.append(f" return \"{clean_name}\";")
|
||||
|
||||
func_lines.append("")
|
||||
func_lines.append(" default:")
|
||||
func_lines.append(" return \"unknown\";")
|
||||
func_lines.append(" }")
|
||||
func_lines.append("}")
|
||||
|
||||
# Full header template
|
||||
header_content = NL.join(enum_lines) + "\n"
|
||||
header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n"
|
||||
|
||||
return header_content
|
||||
|
||||
@@ -170,19 +193,6 @@ class RemotingCodebaseGenerator:
|
||||
|
||||
decl_lines.append(f"{signature} {func['backend_function']}({params});")
|
||||
|
||||
# Switch cases
|
||||
switch_lines = []
|
||||
current_group = None
|
||||
|
||||
for func in functions:
|
||||
if func['group_name'] != current_group:
|
||||
switch_lines.append(f" /* {func['group_description']} */")
|
||||
current_group = func['group_name']
|
||||
|
||||
deprecated = " (DEPRECATED)" if func['deprecated'] else ""
|
||||
|
||||
switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";")
|
||||
|
||||
# Dispatch table
|
||||
table_lines = []
|
||||
current_group = None
|
||||
@@ -201,15 +211,6 @@ class RemotingCodebaseGenerator:
|
||||
|
||||
{NL.join(decl_lines)}
|
||||
|
||||
static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
|
||||
{{
|
||||
switch (type) {{
|
||||
{NL.join(switch_lines)}
|
||||
|
||||
default: return "unknown";
|
||||
}}
|
||||
}}
|
||||
|
||||
extern "C" {{
|
||||
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
|
||||
{NL.join(table_lines)}
|
||||
|
||||
@@ -17,8 +17,8 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
|
||||
size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data);
|
||||
|
||||
virtgpu_shmem temp_shmem; // Local storage for large buffers
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
|
||||
if (cgraph_size <= gpu->data_shmem.mmap_size) {
|
||||
// Lock mutex before using shared data_shmem buffer
|
||||
@@ -26,7 +26,7 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
|
||||
}
|
||||
using_shared_shmem = true;
|
||||
shmem = &gpu->data_shmem;
|
||||
shmem = &gpu->data_shmem;
|
||||
} else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
|
||||
}
|
||||
|
||||
@@ -62,7 +62,9 @@ size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle
|
||||
return max_size;
|
||||
}
|
||||
|
||||
apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) {
|
||||
apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu,
|
||||
apir_buffer_type_host_handle_t host_handle,
|
||||
size_t size) {
|
||||
apir_encoder * encoder;
|
||||
apir_decoder * decoder;
|
||||
ApirForwardReturnCode ret;
|
||||
@@ -84,7 +86,9 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_t
|
||||
return buffer_context;
|
||||
}
|
||||
|
||||
size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) {
|
||||
size_t apir_buffer_type_get_alloc_size(virtgpu * gpu,
|
||||
apir_buffer_type_host_handle_t host_handle,
|
||||
const ggml_tensor * op) {
|
||||
apir_encoder * encoder;
|
||||
apir_decoder * decoder;
|
||||
ApirForwardReturnCode ret;
|
||||
|
||||
@@ -35,8 +35,8 @@ void apir_buffer_set_tensor(virtgpu * gpu,
|
||||
apir_encode_ggml_tensor(encoder, tensor);
|
||||
|
||||
virtgpu_shmem temp_shmem; // Local storage for large buffers
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
|
||||
if (size <= gpu->data_shmem.mmap_size) {
|
||||
// Lock mutex before using shared data_shmem buffer
|
||||
@@ -44,7 +44,7 @@ void apir_buffer_set_tensor(virtgpu * gpu,
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
|
||||
}
|
||||
using_shared_shmem = true;
|
||||
shmem = &gpu->data_shmem;
|
||||
shmem = &gpu->data_shmem;
|
||||
|
||||
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
|
||||
@@ -86,8 +86,8 @@ void apir_buffer_get_tensor(virtgpu * gpu,
|
||||
apir_encode_ggml_tensor(encoder, tensor);
|
||||
|
||||
virtgpu_shmem temp_shmem; // Local storage for large buffers
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
virtgpu_shmem * shmem = &temp_shmem;
|
||||
bool using_shared_shmem = false;
|
||||
|
||||
if (size <= gpu->data_shmem.mmap_size) {
|
||||
// Lock mutex before using shared data_shmem buffer
|
||||
@@ -95,7 +95,7 @@ void apir_buffer_get_tensor(virtgpu * gpu,
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
|
||||
}
|
||||
using_shared_shmem = true;
|
||||
shmem = &gpu->data_shmem;
|
||||
shmem = &gpu->data_shmem;
|
||||
|
||||
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
|
||||
|
||||
@@ -26,7 +26,7 @@ char * apir_device_get_name(virtgpu * gpu) {
|
||||
REMOTE_CALL(gpu, encoder, decoder, ret);
|
||||
|
||||
const size_t string_size = apir_decode_array_size_unchecked(decoder);
|
||||
char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
|
||||
char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
|
||||
if (!string) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__);
|
||||
return NULL;
|
||||
@@ -173,7 +173,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, si
|
||||
REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR);
|
||||
|
||||
if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) {
|
||||
GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer");
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %ldb of guest-host shared buffer", __func__, size);
|
||||
}
|
||||
|
||||
apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id);
|
||||
|
||||
@@ -1,29 +1,36 @@
|
||||
#include "virtgpu.h"
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
#include "virtgpu.h"
|
||||
#include "ggml-remoting.h"
|
||||
#include "backend/shared/apir_backend.h"
|
||||
#include "backend/shared/apir_cs_ggml.h"
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
// clang-format on
|
||||
|
||||
#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \
|
||||
do { \
|
||||
int32_t forward_flag = (int32_t) apir_command_type__; \
|
||||
encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \
|
||||
if (!encoder_name) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \
|
||||
} \
|
||||
#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \
|
||||
int32_t REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__; \
|
||||
const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \
|
||||
do { \
|
||||
encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \
|
||||
if (!encoder_name) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \
|
||||
do { \
|
||||
ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
|
||||
if (!decoder_name) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \
|
||||
} \
|
||||
if (ret_name < APIR_FORWARD_BASE_INDEX) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \
|
||||
apir_forward_error(ret_name), ret_name); \
|
||||
} \
|
||||
ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \
|
||||
#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \
|
||||
do { \
|
||||
ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
|
||||
if (!decoder_name) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \
|
||||
} \
|
||||
if (ret_name < APIR_FORWARD_BASE_INDEX) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \
|
||||
apir_forward_error(ret_name), ret_name); \
|
||||
} \
|
||||
ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \
|
||||
if (ret_name != 0) { \
|
||||
GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)", \
|
||||
REMOTE_CALL_PREPARE_command_name, ret_name); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
@@ -20,6 +20,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu,
|
||||
char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
|
||||
size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
|
||||
size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
|
||||
/* apir_buffer_type_is_host is deprecated. */
|
||||
apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu,
|
||||
apir_buffer_type_host_handle_t host_handle,
|
||||
size_t size);
|
||||
|
||||
@@ -53,9 +53,9 @@ static int virtgpu_handshake(virtgpu * gpu) {
|
||||
|
||||
if (!decoder) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to initiate the communication with the virglrenderer library. "
|
||||
"Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
|
||||
__func__);
|
||||
"%s: failed to initiate the communication with the virglrenderer library. "
|
||||
"Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
|
||||
__func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -65,8 +65,7 @@ static int virtgpu_handshake(virtgpu * gpu) {
|
||||
uint32_t host_minor;
|
||||
|
||||
if (ret_magic != APIR_HANDSHAKE_MAGIC) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
|
||||
apir_backend_initialize_error(ret_magic));
|
||||
} else {
|
||||
apir_decode_uint32_t(decoder, &host_major);
|
||||
@@ -140,15 +139,13 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
|
||||
"Make sure virglrenderer is correctly configured by the hypervisor. (%s) ",
|
||||
__func__, apir_load_library_error(ret));
|
||||
} else {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__,
|
||||
apir_load_library_error(ret), ret);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)",
|
||||
__func__, apir_load_library_error(ret), ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
GGML_LOG_INFO(GGML_VIRTGPU
|
||||
"%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
|
||||
GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
|
||||
|
||||
ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
|
||||
|
||||
@@ -158,10 +155,11 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
|
||||
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
|
||||
__func__, apir_load_library_error(apir_ret));
|
||||
} else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
|
||||
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
|
||||
__func__, apir_load_library_error(apir_ret));
|
||||
GGML_ABORT(
|
||||
GGML_VIRTGPU
|
||||
"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
|
||||
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
|
||||
__func__, apir_load_library_error(apir_ret));
|
||||
} else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)",
|
||||
@@ -169,8 +167,8 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
|
||||
} else {
|
||||
uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX;
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__,
|
||||
lib_ret);
|
||||
"%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)",
|
||||
__func__, lib_ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@@ -184,55 +182,49 @@ virtgpu * create_virtgpu() {
|
||||
// Initialize mutex to protect shared data_shmem buffer
|
||||
if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) {
|
||||
delete gpu;
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to initialize data_shmem mutex", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_open(gpu) != APIR_SUCCESS) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU
|
||||
"%s: failed to open the virtgpu device\n", __func__);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_init_capset(gpu) != APIR_SUCCESS) {
|
||||
if (gpu->use_apir_capset) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__);
|
||||
"%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library "
|
||||
"supports it.",
|
||||
__func__);
|
||||
} else {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to initialize the virtgpu Venus capset", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_init_context(gpu) != APIR_SUCCESS) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to initialize the GPU context", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to create the shared reply memory pages", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to create the shared data memory pages", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_handshake(gpu)) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to handshake with the virglrenderer library", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to load the backend library", __func__);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -243,8 +235,7 @@ static virt_gpu_result_t virtgpu_open(virtgpu * gpu) {
|
||||
drmDevicePtr devs[8];
|
||||
int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs));
|
||||
if (count < 0) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU
|
||||
"%s: failed to enumerate DRM devices\n", __func__);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__);
|
||||
return APIR_ERROR_INITIALIZATION_FAILED;
|
||||
}
|
||||
|
||||
@@ -266,19 +257,17 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d
|
||||
|
||||
int fd = open(node_path, O_RDWR | O_CLOEXEC);
|
||||
if (fd < 0) {
|
||||
GGML_ABORT(GGML_VIRTGPU
|
||||
"%s: failed to open %s", __func__, node_path);
|
||||
GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path);
|
||||
return APIR_ERROR_INITIALIZATION_FAILED;
|
||||
}
|
||||
|
||||
drmVersionPtr version = drmGetVersion(fd);
|
||||
if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) {
|
||||
if (version) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU
|
||||
"%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name,
|
||||
version->version_major);
|
||||
} else {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU
|
||||
"%s: failed to get DRM driver version\n", __func__);
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__);
|
||||
}
|
||||
|
||||
if (version) {
|
||||
@@ -322,9 +311,8 @@ static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) {
|
||||
virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data));
|
||||
|
||||
if (ret) {
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU
|
||||
"%s: failed to get APIR v%d capset: %s\n",
|
||||
__func__, gpu->capset.version, strerror(errno));
|
||||
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get APIR v%d capset: %s\n", __func__, gpu->capset.version,
|
||||
strerror(errno));
|
||||
return APIR_ERROR_INITIALIZATION_FAILED;
|
||||
}
|
||||
|
||||
@@ -547,13 +535,10 @@ static void log_call_duration(long long call_duration_ns, const char * name) {
|
||||
double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds
|
||||
|
||||
if (call_duration_s > 1) {
|
||||
GGML_LOG_INFO(GGML_VIRTGPU
|
||||
"waited %.2fs for the %s host reply...\n", call_duration_s, name);
|
||||
GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name);
|
||||
} else if (call_duration_ms > 1) {
|
||||
GGML_LOG_INFO(GGML_VIRTGPU
|
||||
"waited %.2fms for the %s host reply...\n", call_duration_ms, name);
|
||||
GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name);
|
||||
} else {
|
||||
GGML_LOG_INFO(GGML_VIRTGPU
|
||||
"waited %lldns for the %s host reply...\n", call_duration_ns, name);
|
||||
GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
#include "virtgpu-utils.h"
|
||||
#include "virtgpu-shm.h"
|
||||
#include "virtgpu-apir.h"
|
||||
@@ -23,20 +24,21 @@
|
||||
#include "apir_hw.h"
|
||||
#include <drm/virtgpu_drm.h>
|
||||
#include "venus_hw.h"
|
||||
// clang-format on
|
||||
|
||||
#ifndef VIRTGPU_DRM_CAPSET_APIR
|
||||
// Will be defined include/drm/virtgpu_drm.h when
|
||||
// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs
|
||||
// is merged
|
||||
#define VIRTGPU_DRM_CAPSET_APIR 10
|
||||
# define VIRTGPU_DRM_CAPSET_APIR 10
|
||||
#endif
|
||||
|
||||
// Mesa/Virlgrenderer Venus internal. Only necessary during the
|
||||
// Venus->APIR transition in Virglrenderer
|
||||
#define VENUS_COMMAND_TYPE_LENGTH 331
|
||||
|
||||
#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16
|
||||
#define VIRTGPU_DRM_CAPSET_VENUS 4
|
||||
#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16
|
||||
# define VIRTGPU_DRM_CAPSET_VENUS 4
|
||||
#endif
|
||||
|
||||
typedef uint32_t virgl_renderer_capset;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user