mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-18 06:54:06 +00:00
Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2054e3a8f | ||
|
|
dd52868050 | ||
|
|
6b9a52422b | ||
|
|
2f966b8ed8 | ||
|
|
cd5e3b5754 | ||
|
|
87c9efc3b2 | ||
|
|
76af40aaaa | ||
|
|
7db35a7958 | ||
|
|
a864132ba5 | ||
|
|
d38d9f0877 | ||
|
|
7fd205a8e8 | ||
|
|
2f68ce7cfd | ||
|
|
e4a71599e5 | ||
|
|
dd5e8cab51 | ||
|
|
cf659bbb8e | ||
|
|
d8b860a219 | ||
|
|
1ae74882f8 | ||
|
|
961660b8c3 | ||
|
|
74fef4129f | ||
|
|
5d8bb900bc | ||
|
|
2e76e01360 | ||
|
|
d3dc9dd898 | ||
|
|
bea04522ff | ||
|
|
0de0a01576 | ||
|
|
e58d585604 | ||
|
|
31c511a968 | ||
|
|
6d39015a74 | ||
|
|
4146d6a1a6 | ||
|
|
8da3c0e200 | ||
|
|
c22473b580 | ||
|
|
0f715b4e75 | ||
|
|
d2d931f173 | ||
|
|
2976b0374d | ||
|
|
d2a2673dd1 | ||
|
|
13002a0896 | ||
|
|
6eb208d17e | ||
|
|
9984cbb61d | ||
|
|
ce18efeaf1 | ||
|
|
16724b5b68 | ||
|
|
b52edd2558 | ||
|
|
517b7170e1 | ||
|
|
835e918d84 | ||
|
|
d261223d24 | ||
|
|
dcca0d3ab8 | ||
|
|
bacddc049a | ||
|
|
229bf68628 | ||
|
|
d7395115ba | ||
|
|
052df28b0e | ||
|
|
8b11deea46 | ||
|
|
b9ce940177 | ||
|
|
3464bdac37 | ||
|
|
e3af5563bd | ||
|
|
10fcc41290 | ||
|
|
bcf5bda6f5 | ||
|
|
3eb2be1ca5 | ||
|
|
e41bcce8f0 | ||
|
|
144a4ce824 | ||
|
|
f549b0007d | ||
|
|
9a3ea685b9 | ||
|
|
338074c383 | ||
|
|
851553ea6b | ||
|
|
85a7d8677b | ||
|
|
a8ca18b4b8 | ||
|
|
8284efc35c | ||
|
|
1c1409e131 | ||
|
|
7a0e900e36 | ||
|
|
280d97be96 | ||
|
|
3479efd112 | ||
|
|
463bbf20bf | ||
|
|
ad8d36beff | ||
|
|
c053e18a66 | ||
|
|
e1ab084803 | ||
|
|
5a4ff43e7d | ||
|
|
10640e31aa | ||
|
|
80d28f104c | ||
|
|
c55d53acec | ||
|
|
945501f5ea | ||
|
|
75cbdd3fce | ||
|
|
2b9bd9bf4e | ||
|
|
59fc1ec8e8 |
@@ -24,8 +24,9 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DGGML_BACKEND_DL=OFF \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGGML_BLAS=ON \
|
||||
-DGGML_BLAS_VENDOR=OpenBLAS && \
|
||||
cmake --build build --config Release -j $(nproc) && \
|
||||
@@ -103,6 +104,7 @@ FROM base AS light
|
||||
WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin
|
||||
|
||||
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
|
||||
@@ -116,6 +118,7 @@ ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
74
.github/workflows/build-linux-cross.yml
vendored
74
.github/workflows/build-linux-cross.yml
vendored
@@ -4,49 +4,49 @@ on:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
ubuntu-24-riscv64-cpu-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
# ubuntu-24-riscv64-cpu-cross:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Riscv
|
||||
run: |
|
||||
sudo dpkg --add-architecture riscv64
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
EOF
|
||||
# # Add arch-specific repositories for non-amd64 architectures
|
||||
# cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
# deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
# EOF
|
||||
|
||||
sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
# sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-14-riscv64-linux-gnu \
|
||||
g++-14-riscv64-linux-gnu
|
||||
# sudo apt-get install -y --no-install-recommends \
|
||||
# build-essential \
|
||||
# gcc-14-riscv64-linux-gnu \
|
||||
# g++-14-riscv64-linux-gnu
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_CURL=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=riscv64 \
|
||||
-DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
# - name: Build
|
||||
# run: |
|
||||
# cmake -B build -DLLAMA_CURL=OFF \
|
||||
# -DCMAKE_BUILD_TYPE=Release \
|
||||
# -DGGML_OPENMP=OFF \
|
||||
# -DLLAMA_BUILD_EXAMPLES=ON \
|
||||
# -DLLAMA_BUILD_TOOLS=ON \
|
||||
# -DLLAMA_BUILD_TESTS=OFF \
|
||||
# -DCMAKE_SYSTEM_NAME=Linux \
|
||||
# -DCMAKE_SYSTEM_PROCESSOR=riscv64 \
|
||||
# -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \
|
||||
# -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \
|
||||
# -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
# -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \
|
||||
# -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
# -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
# -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
# cmake --build build --config Release -j $(nproc)
|
||||
|
||||
# ubuntu-24-riscv64-vulkan-cross:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
2
.github/workflows/docker.yml
vendored
2
.github/workflows/docker.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
# https://github.com/ggml-org/llama.cpp/issues/11888
|
||||
#- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -134,8 +134,8 @@ jobs:
|
||||
include:
|
||||
- build: 'x64'
|
||||
os: ubuntu-22.04
|
||||
- build: 's390x-z15' # z15 because our CI runners are on z15
|
||||
os: ubuntu-22.04-s390x
|
||||
- build: 's390x'
|
||||
os: ubuntu-24.04-s390x
|
||||
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
|
||||
# - build: 'arm64'
|
||||
# os: ubuntu-22.04-arm
|
||||
|
||||
@@ -65,7 +65,7 @@
|
||||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||
/ggml/src/ggml-metal/ @ggerganov
|
||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
|
||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||
/ggml/src/ggml-quants.* @ggerganov
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
@@ -89,6 +89,7 @@
|
||||
/src/llama-model-loader.* @slaren
|
||||
/src/llama-model.* @CISC
|
||||
/src/llama-vocab.* @CISC
|
||||
/src/models/ @CISC
|
||||
/tests/ @ggerganov
|
||||
/tests/test-backend-ops.cpp @slaren
|
||||
/tests/test-thread-safety.cpp @slaren
|
||||
|
||||
@@ -2030,7 +2030,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.system_prompt.pop_back();
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
|
||||
add_opt(common_arg(
|
||||
{"--in-file"}, "FNAME",
|
||||
"an input file (repeat to specify multiple files)",
|
||||
@@ -3203,7 +3203,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
|
||||
add_opt(common_arg(
|
||||
{"--parse-special"},
|
||||
string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
|
||||
string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.parse_special = true;
|
||||
}
|
||||
@@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--embd-output-format"}, "FORMAT",
|
||||
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix",
|
||||
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.embd_out = value;
|
||||
}
|
||||
|
||||
217
common/chat.cpp
217
common/chat.cpp
@@ -9,8 +9,11 @@
|
||||
#include <minja/chat-template.hpp>
|
||||
#include <minja/minja.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cctype>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
@@ -310,7 +313,6 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
@@ -640,6 +642,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
||||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -986,6 +989,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
// Case-insensitive find
|
||||
static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
|
||||
auto it = std::search(
|
||||
haystack.begin() + pos, haystack.end(),
|
||||
needle.begin(), needle.end(),
|
||||
[](char a, char b) { return std::tolower(a) == std::tolower(b); }
|
||||
);
|
||||
return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
const auto is_json_schema_provided = !inputs.json_schema.is_null();
|
||||
const auto is_grammar_provided = !inputs.grammar.empty();
|
||||
const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
|
||||
// the logic requires potentially modifying the messages
|
||||
auto tweaked_messages = inputs.messages;
|
||||
|
||||
auto replace_json_schema_marker = [](json & messages) -> bool {
|
||||
static std::string marker1 = "force json schema.\n";
|
||||
static std::string marker2 = "force json schema.";
|
||||
|
||||
if (messages.empty() || messages.at(0).at("role") != "system") {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string content = messages.at(0).at("content");
|
||||
|
||||
for (const auto & marker : {marker1, marker2}) {
|
||||
const auto pos = ifind_string(content, marker);
|
||||
if (pos != std::string::npos) {
|
||||
content.replace(pos, marker.length(), "");
|
||||
// inject modified content back into the messages
|
||||
messages.at(0).at("content") = content;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
// Lfm2 model does not natively work with json, but can generally understand the tools structure
|
||||
//
|
||||
// Example of the pytorch dialog structure:
|
||||
// <|startoftext|><|im_start|>system
|
||||
// List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
|
||||
// <|im_start|>user
|
||||
// What is the current status of candidate ID 12345?<|im_end|>
|
||||
// <|im_start|>assistant
|
||||
// <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
|
||||
// <|im_start|>tool
|
||||
// <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
|
||||
// <|im_start|>assistant
|
||||
// The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
|
||||
//
|
||||
// For the llama server compatibility with json tools semantic,
|
||||
// the client can add "Follow json schema." line into the system message prompt to force the json output.
|
||||
//
|
||||
if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
|
||||
// server/utils.hpp prohibits that branch for the custom grammar anyways
|
||||
throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
|
||||
} else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
|
||||
LOG_INF("%s: Using tools to build a grammar\n", __func__);
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"arguments", function.at("parameters")},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
|
||||
builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
|
||||
});
|
||||
// model has no concept of tool selection mode choice,
|
||||
// if the system prompt rendered correctly it will produce a tool call
|
||||
// the grammar goes inside the tool call body
|
||||
data.grammar_lazy = true;
|
||||
data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
|
||||
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
|
||||
data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
|
||||
} else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
|
||||
LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
|
||||
// output those tokens
|
||||
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
|
||||
} else if (is_json_schema_provided) {
|
||||
LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
|
||||
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||
} else if (is_grammar_provided) {
|
||||
LOG_INF("%s: Using provided grammar\n", __func__);
|
||||
data.grammar = inputs.grammar;
|
||||
} else {
|
||||
LOG_INF("%s: Using content relying on the template\n", __func__);
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
|
||||
LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -1686,7 +1809,23 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
auto prompt = apply(tmpl, inputs);
|
||||
|
||||
// Copy reasoning to the "thinking" field as expected by the gpt-oss template
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
||||
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["thinking"] = msg.at("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the return token with end token during
|
||||
// inference and without generation prompt. For more details see:
|
||||
@@ -2499,6 +2638,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
|
||||
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
|
||||
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
|
||||
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
|
||||
|
||||
// Loop through all tool calls
|
||||
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
|
||||
// Consume end marker
|
||||
builder.consume_spaces();
|
||||
if (!builder.try_consume_regex(tool_call_end_regex)) {
|
||||
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
|
||||
}
|
||||
|
||||
// Process each tool call in the array
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
for (const auto & tool_call : tool_calls_data.json) {
|
||||
if (!tool_call.is_object()) {
|
||||
throw common_chat_msg_partial_exception("Tool call must be an object");
|
||||
}
|
||||
|
||||
if (!tool_call.contains("name")) {
|
||||
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
|
||||
}
|
||||
|
||||
std::string function_name = tool_call.at("name");
|
||||
std::string arguments = "{}";
|
||||
|
||||
if (tool_call.contains("arguments")) {
|
||||
if (tool_call.at("arguments").is_object()) {
|
||||
arguments = tool_call.at("arguments").dump();
|
||||
} else if (tool_call.at("arguments").is_string()) {
|
||||
arguments = tool_call.at("arguments");
|
||||
}
|
||||
}
|
||||
|
||||
if (!builder.add_tool_call(function_name, "", arguments)) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
|
||||
}
|
||||
|
||||
// Consume any trailing whitespace after this tool call
|
||||
builder.consume_spaces();
|
||||
}
|
||||
|
||||
// Consume any remaining content after all tool calls
|
||||
auto remaining = builder.consume_rest();
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags first - this handles the main reasoning content
|
||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
||||
@@ -2748,6 +2952,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_apertus(tmpl, params);
|
||||
}
|
||||
|
||||
// LFM2 (w/ tools)
|
||||
if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
|
||||
src.find("]<|tool_list_end|>") != std::string::npos) {
|
||||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
@@ -2926,6 +3136,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_APERTUS:
|
||||
common_chat_parse_apertus(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||
common_chat_parse_lfm2(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -601,7 +601,10 @@ private:
|
||||
}
|
||||
|
||||
std::string _resolve_ref(const std::string & ref) {
|
||||
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
|
||||
auto it = ref.find('#');
|
||||
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
|
||||
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
|
||||
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
|
||||
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
||||
_refs_being_resolved.insert(ref);
|
||||
json resolved = _refs[ref];
|
||||
@@ -774,11 +777,24 @@ public:
|
||||
std::vector<std::string> tokens = string_split(pointer, "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
std::string sel = tokens[i];
|
||||
if (target.is_null() || !target.contains(sel)) {
|
||||
if (target.is_object() && target.contains(sel)) {
|
||||
target = target[sel];
|
||||
} else if (target.is_array()) {
|
||||
size_t sel_index;
|
||||
try {
|
||||
sel_index = std::stoul(sel);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
sel_index = target.size();
|
||||
}
|
||||
if (sel_index >= target.size()) {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel_index];
|
||||
} else {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel];
|
||||
}
|
||||
_refs[ref] = target;
|
||||
}
|
||||
|
||||
@@ -1054,6 +1054,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
|
||||
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
|
||||
res = "granite-docling"
|
||||
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
|
||||
# ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
|
||||
res = "minimax-m2"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -1528,7 +1531,7 @@ class MmprojModel(ModelBase):
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
@@ -2460,18 +2463,21 @@ class ArceeModel(LlamaModel):
|
||||
)
|
||||
class LlavaVisionModel(MmprojModel):
|
||||
img_break_tok_id = -1
|
||||
use_break_tok = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.hparams.get("model_type") == "pixtral":
|
||||
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
|
||||
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
elif self.is_mistral_format:
|
||||
# hparams is already vision config here so norm_eps is only defined in global_config.
|
||||
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
|
||||
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
|
||||
logger.info(f"Image break token id: {self.img_break_tok_id}")
|
||||
@@ -3849,7 +3855,43 @@ class Qwen2MoeModel(TextModel):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
name = name.replace("language_model.", "") # InternVL
|
||||
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
|
||||
|
||||
# handle aggregated expert tensors
|
||||
# GGUF stores dimensions reversed from PyTorch, so:
|
||||
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
|
||||
# Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
|
||||
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
|
||||
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
|
||||
mapped = f"{name}.weight" if not name.endswith(".weight") else name
|
||||
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
|
||||
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
|
||||
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
|
||||
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
|
||||
permuted = data_torch.permute(0, 2, 1).contiguous()
|
||||
return [(self.map_tensor_name(mapped), permuted)]
|
||||
|
||||
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
|
||||
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
|
||||
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
|
||||
split_dim = data_torch.shape[-1] // 2
|
||||
gate = data_torch[..., :split_dim].contiguous()
|
||||
up = data_torch[..., split_dim:].contiguous()
|
||||
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
|
||||
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
|
||||
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
|
||||
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
|
||||
base_name = name.removesuffix(".weight")
|
||||
base = base_name.rsplit('.', 1)[0]
|
||||
mapped_gate = f"{base}.gate_proj.weight"
|
||||
mapped_up = f"{base}.up_proj.weight"
|
||||
perm_gate = gate.permute(0, 2, 1).contiguous()
|
||||
perm_up = up.permute(0, 2, 1).contiguous()
|
||||
return [
|
||||
(self.map_tensor_name(mapped_gate), perm_gate),
|
||||
(self.map_tensor_name(mapped_up), perm_up),
|
||||
]
|
||||
|
||||
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
|
||||
# skip visual tensors
|
||||
return []
|
||||
if name.find("experts") != -1:
|
||||
@@ -3962,6 +4004,10 @@ class Qwen3Model(Qwen2Model):
|
||||
return torch.stack([true_row, false_row], dim=0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "model.vision_" in name:
|
||||
# skip multimodal tensors
|
||||
return []
|
||||
|
||||
if self.is_rerank:
|
||||
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
|
||||
is_real_head = not self.is_tied_embeddings and "lm_head" in name
|
||||
@@ -3997,6 +4043,187 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
|
||||
class Qwen3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
# Compute image_size if not present
|
||||
if "image_size" not in self.hparams_vision:
|
||||
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
|
||||
num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
|
||||
patch_size = self.hparams_vision.get("patch_size", 16)
|
||||
# num_position_embeddings = (image_size / patch_size) ** 2
|
||||
# So image_size = sqrt(num_position_embeddings) * patch_size
|
||||
image_size = int(num_pos**0.5 * patch_size)
|
||||
self.hparams_vision["image_size"] = image_size
|
||||
|
||||
# Rename config values for compatibility
|
||||
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
|
||||
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
|
||||
|
||||
self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
|
||||
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
|
||||
self.is_deepstack_layers[idx] = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
|
||||
if self.hparams_vision is not None:
|
||||
merge_size = self.hparams_vision.get("spatial_merge_size")
|
||||
if merge_size is not None:
|
||||
self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))
|
||||
|
||||
# Use text config's rms_norm_eps for vision attention layernorm eps
|
||||
rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
|
||||
|
||||
if self.is_deepstack_layers:
|
||||
self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
assert self.hparams_vision is not None
|
||||
# Skip text model tensors - they go in the text model file
|
||||
if name.startswith("model.language_model.") or name.startswith("lm_head."):
|
||||
return []
|
||||
|
||||
if name.startswith("model.visual."):
|
||||
name = name.replace("model.visual.", "visual.", 1)
|
||||
|
||||
if name.startswith("visual.deepstack_merger_list."):
|
||||
prefix, rest = name.split(".", maxsplit=3)[2:]
|
||||
# prefix is the layer index, convert to absolute clip layer index!
|
||||
idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
|
||||
target = rest
|
||||
|
||||
tensor_type: gguf.MODEL_TENSOR
|
||||
if target.startswith("norm."):
|
||||
tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
|
||||
suffix = target.split(".", 1)[1]
|
||||
elif target.startswith("linear_fc1."):
|
||||
tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
|
||||
suffix = target.split(".", 1)[1]
|
||||
elif target.startswith("linear_fc2."):
|
||||
tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
|
||||
suffix = target.split(".", 1)[1]
|
||||
else:
|
||||
raise ValueError(f"Unexpected deepstack tensor: {name}")
|
||||
|
||||
new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
if name.startswith("visual.merger."):
|
||||
suffix = name.split(".", 2)[2]
|
||||
if suffix.startswith("linear_fc"):
|
||||
fc_idx_str, tail = suffix.split(".", 1)
|
||||
fc_num = int(fc_idx_str.replace("linear_fc", ""))
|
||||
# Qwen3VL has linear_fc1 and linear_fc2
|
||||
# Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
|
||||
if fc_num == 1:
|
||||
fc_idx = 0
|
||||
elif fc_num == 2:
|
||||
fc_idx = 2
|
||||
else:
|
||||
raise ValueError(f"unexpected fc index {fc_num} in {name}")
|
||||
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
|
||||
elif suffix.startswith("norm."):
|
||||
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
|
||||
else:
|
||||
raise ValueError(f"Unexpected merger tensor: {name}")
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
if name == "visual.patch_embed.proj.weight":
|
||||
# split Conv3D into Conv2Ds along temporal dimension
|
||||
c1, c2, kt, _, _ = data_torch.shape
|
||||
del c1, c2
|
||||
if kt != 2:
|
||||
raise ValueError("Current implementation only supports temporal_patch_size of 2")
|
||||
return [
|
||||
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
|
||||
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
|
||||
]
|
||||
|
||||
if name == "visual.patch_embed.proj.bias":
|
||||
# Include the bias - it's used by the C++ code
|
||||
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]
|
||||
|
||||
if name.startswith("visual."):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
# Fall back to parent class for other tensors
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLForConditionalGeneration")
|
||||
class Qwen3VLTextModel(Qwen3Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
|
||||
text_config = self.hparams.get("text_config", {})
|
||||
# rope_scaling is deprecated in V5, use rope_parameters instead
|
||||
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
|
||||
|
||||
if rope_scaling.get("mrope_section"):
|
||||
# mrope_section contains [time, height, width] dimensions
|
||||
mrope_section = rope_scaling["mrope_section"]
|
||||
# Pad to 4 dimensions [time, height, width, extra]
|
||||
while len(mrope_section) < 4:
|
||||
mrope_section.append(0)
|
||||
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
|
||||
|
||||
logger.info(f"MRoPE sections: {mrope_section[:4]}")
|
||||
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
|
||||
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision tensors - they go in the mmproj file
|
||||
if name.startswith("model.visual."):
|
||||
return []
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
|
||||
class Qwen3VLMoeTextModel(Qwen3MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
|
||||
text_config = self.hparams.get("text_config", {})
|
||||
# rope_scaling is deprecated in V5, use rope_parameters instead
|
||||
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
|
||||
|
||||
if rope_scaling.get("mrope_section"):
|
||||
# mrope_section contains [time, height, width] dimensions
|
||||
mrope_section = rope_scaling["mrope_section"]
|
||||
# Pad to 4 dimensions [time, height, width, extra]
|
||||
while len(mrope_section) < 4:
|
||||
mrope_section.append(0)
|
||||
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
|
||||
|
||||
logger.info(f"MRoPE sections: {mrope_section[:4]}")
|
||||
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
|
||||
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision tensors - they go in the mmproj file
|
||||
if name.startswith("model.visual."):
|
||||
return []
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GPT2LMHeadModel")
|
||||
class GPT2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
@@ -6902,6 +7129,64 @@ class DeepseekV2Model(TextModel):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("MiniMaxM2ForCausalLM")
|
||||
class MiniMaxM2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.MINIMAXM2
|
||||
_experts_cache: dict[int, dict[str, Tensor]] = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hparams["num_experts"] = self.hparams["num_local_experts"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif self.hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.endswith("e_score_correction_bias"):
|
||||
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
|
||||
# merge expert weights
|
||||
if 'experts' in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
expert_cache = self._experts_cache.setdefault(bid, {})
|
||||
expert_cache[name] = data_torch
|
||||
expert_weights = ["w1", "w2", "w3"]
|
||||
|
||||
# not enough expert weights to merge
|
||||
if len(expert_cache) < n_experts * len(expert_weights):
|
||||
return []
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
for w_name in expert_weights:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
|
||||
datas.append(expert_cache[ename])
|
||||
del expert_cache[ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
del self._experts_cache[bid]
|
||||
return tensors
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Dots1ForCausalLM")
|
||||
class Dots1Model(Qwen2MoeModel):
|
||||
model_arch = gguf.MODEL_ARCH.DOTS1
|
||||
@@ -9435,6 +9720,21 @@ class PixtralModel(LlavaVisionModel):
|
||||
return super().map_tensor_name(name, try_suffixes)
|
||||
|
||||
|
||||
@ModelBase.register("LightOnOCRForConditionalGeneration")
|
||||
class LightOnOCRVisionModel(LlavaVisionModel):
|
||||
is_mistral_format = False
|
||||
use_break_tok = False
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("model.vision_encoder.", "vision_tower.")
|
||||
name = name.replace("model.vision_projection.", "multi_modal_projector.")
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
class KimiVLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -9471,6 +9771,144 @@ class KimiVLModel(MmprojModel):
|
||||
|
||||
return [] # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("CogVLMForCausalLM")
|
||||
class CogVLMVisionModel(MmprojModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.COGVLM)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if not name.startswith("model.vision."):
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("CogVLMForCausalLM")
|
||||
class CogVLMModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.COGVLM
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# block vision tensors
|
||||
if name.startswith("model.vision."):
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("JanusForConditionalGeneration")
|
||||
class JanusProModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA # reuse Llama arch
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision, aligner, and generation tensors
|
||||
skip_prefixes = (
|
||||
'model.vision_model.',
|
||||
'model.aligner.',
|
||||
'model.vqmodel.',
|
||||
'model.generation_embeddings.',
|
||||
'model.generation_aligner.',
|
||||
'model.generation_head.',
|
||||
)
|
||||
if name.startswith(skip_prefixes):
|
||||
return []
|
||||
|
||||
if name.startswith('model.language_model.'):
|
||||
name = name.replace('model.language_model.', 'model.')
|
||||
elif name.startswith('language_model.'):
|
||||
name = name.replace('language_model.', '')
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("JanusForConditionalGeneration")
|
||||
class JanusProVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
if "intermediate_size" not in self.hparams_vision:
|
||||
mlp_ratio = self.hparams_vision.get("mlp_ratio")
|
||||
hidden_size = self.hparams_vision.get("hidden_size")
|
||||
if mlp_ratio is not None and hidden_size is not None:
|
||||
self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)
|
||||
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
|
||||
|
||||
hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
|
||||
if hidden_act == "gelu":
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
elif hidden_act == "silu":
|
||||
self.gguf_writer.add_vision_use_silu(True)
|
||||
|
||||
def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
|
||||
"""Map aligner tensors to projector format"""
|
||||
suffix = ".bias" if name.endswith(".bias") else ".weight"
|
||||
|
||||
if name.startswith("model.aligner."):
|
||||
local_name = name[len("model.aligner."):]
|
||||
elif name.startswith("aligner."):
|
||||
local_name = name[len("aligner."):]
|
||||
else:
|
||||
raise ValueError(f"Unsupported Janus aligner prefix: {name}")
|
||||
|
||||
if local_name.startswith("fc1."):
|
||||
mm_index = 0
|
||||
elif local_name.startswith("hidden_layers."):
|
||||
parts = local_name.split(".", 2)
|
||||
if len(parts) < 3:
|
||||
raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
|
||||
mm_index = int(parts[1]) + 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported Janus aligner tensor: {name}")
|
||||
|
||||
tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
|
||||
return [(tensor_name, data_torch)]
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# Skip language model tensors as they will be handled by `JanusProModel`
|
||||
if name.startswith(('model.language_model.', 'language_model.')):
|
||||
return []
|
||||
|
||||
# Skip generation-related components
|
||||
skip_generation_prefixes = (
|
||||
'model.vqmodel.',
|
||||
'vqmodel.',
|
||||
'model.generation_embeddings.',
|
||||
'generation_embeddings.',
|
||||
'model.generation_aligner.',
|
||||
'generation_aligner.',
|
||||
'model.generation_head.',
|
||||
'generation_head.',
|
||||
)
|
||||
if name.startswith(skip_generation_prefixes):
|
||||
return []
|
||||
|
||||
# Handle aligner tensors
|
||||
if name.startswith(('model.aligner.', 'aligner.')):
|
||||
return list(self._map_aligner_tensor(data_torch, name))
|
||||
|
||||
# Handle vision tensors
|
||||
if name.startswith(('model.vision_model.', 'vision_model.')):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
||||
@@ -141,6 +141,7 @@ models = [
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
@@ -435,7 +436,7 @@ for model in models:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
|
||||
except OSError as e:
|
||||
except (OSError, TypeError) as e:
|
||||
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
|
||||
continue # Skip this model and continue with the next one in the loop
|
||||
|
||||
|
||||
@@ -261,10 +261,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
|
||||
```bash
|
||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build --config Release -- -j 16
|
||||
```
|
||||
|
||||
Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system.
|
||||
|
||||
To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system.
|
||||
|
||||
The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager.
|
||||
@@ -282,17 +284,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
||||
```bash
|
||||
HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \
|
||||
HIP_DEVICE_LIB_PATH=<directory-you-just-found> \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
|
||||
&& cmake --build build -- -j 16
|
||||
```
|
||||
|
||||
- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
|
||||
```bash
|
||||
set PATH=%HIP_PATH%\bin;%PATH%
|
||||
cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
||||
cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
||||
If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
|
||||
Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
## Images
|
||||
We have three Docker images available for this project:
|
||||
|
||||
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
|
||||
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
|
||||
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
|
||||
|
||||
Additionally, there the following images, similar to the above:
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ Legend:
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
||||
@@ -5637,25 +5637,25 @@
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL"
|
||||
"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL"
|
||||
"SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL"
|
||||
"SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL"
|
||||
|
||||
|
Can't render this file because it is too large.
|
@@ -38,6 +38,7 @@ The above command will output space-separated float values.
|
||||
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
|
||||
| 'json' | openai style |
|
||||
| 'json+' | add cosine similarity matrix |
|
||||
| 'raw' | plain text output |
|
||||
|
||||
### --embd-separator $"string"$
|
||||
| $"string"$ | |
|
||||
|
||||
@@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
||||
}
|
||||
}
|
||||
|
||||
// plain, pipe-friendly output: one embedding per line
|
||||
static void print_raw_embeddings(const float * emb,
|
||||
int n_embd_count,
|
||||
int n_embd,
|
||||
const llama_model * model,
|
||||
enum llama_pooling_type pooling_type,
|
||||
int embd_normalize) {
|
||||
const uint32_t n_cls_out = llama_model_n_cls_out(model);
|
||||
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
|
||||
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
|
||||
|
||||
for (int j = 0; j < n_embd_count; ++j) {
|
||||
for (int i = 0; i < cols; ++i) {
|
||||
if (embd_normalize == 0) {
|
||||
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||
} else {
|
||||
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||
}
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
@@ -372,6 +395,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (notArray) LOG("\n}\n");
|
||||
} else if (params.embd_out == "raw") {
|
||||
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
|
||||
}
|
||||
|
||||
LOG("\n");
|
||||
|
||||
@@ -371,8 +371,17 @@ class SchemaConverter:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
if isinstance(target, list):
|
||||
try:
|
||||
sel_index = int(sel)
|
||||
except ValueError:
|
||||
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
|
||||
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel_index]
|
||||
else:
|
||||
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
@@ -547,7 +556,8 @@ class SchemaConverter:
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_name = ref.split('/')[-1]
|
||||
ref_fragment = ref.split('#')[-1]
|
||||
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
|
||||
@@ -242,6 +242,7 @@
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
#define GGML_ROPE_TYPE_MROPE 8
|
||||
#define GGML_ROPE_TYPE_VISION 24
|
||||
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
|
||||
|
||||
#define GGML_MROPE_SECTIONS 4
|
||||
|
||||
|
||||
@@ -308,6 +308,10 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
foreach (feat VXE2 NNPA)
|
||||
set(GGML_INTERNAL_${feat} OFF)
|
||||
endforeach()
|
||||
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
@@ -377,9 +381,8 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
|
||||
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
|
||||
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
|
||||
ggml_add_cpu_backend_variant(z15 Z15 VXE2)
|
||||
ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||
ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
|
||||
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
theta_scale_ne, theta_scale_nb, 1);
|
||||
|
||||
float start = 0;
|
||||
float step = 1;
|
||||
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
|
||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
|
||||
theta_scale_nb, GGML_MAX_DIMS);
|
||||
theta_scale_nb, 1);
|
||||
float zero_value = 0, one_value = 1;
|
||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||
|
||||
@@ -67,19 +67,30 @@
|
||||
GGML_ABORT("CANN error");
|
||||
}
|
||||
|
||||
// Thread-local variable to record the current device of this thread.
|
||||
thread_local int g_current_cann_device = -1;
|
||||
|
||||
/**
|
||||
* @brief Sets the device to be used by CANN.
|
||||
* @brief Set the CANN device to be used.
|
||||
*
|
||||
* @param device The device ID to set.
|
||||
* @param device The target device ID to set.
|
||||
*/
|
||||
void ggml_cann_set_device(const int32_t device) {
|
||||
int current_device = -1;
|
||||
aclrtGetDevice(¤t_device);
|
||||
// int current_device = -1;
|
||||
// Note: In some CANN versions, if no device has been set yet,
|
||||
// aclrtGetDevice(¤t_device) may return 0 by default.
|
||||
// aclrtGetDevice(¤t_device);
|
||||
|
||||
if (device == current_device) {
|
||||
// If the current device is already the target one, no need to switch.
|
||||
if (device == g_current_cann_device) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Switch to the new device.
|
||||
ACL_CHECK(aclrtSetDevice(device));
|
||||
|
||||
// Update the global device record.
|
||||
g_current_cann_device = device;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -504,11 +504,18 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if (GGML_VXE OR GGML_INTERNAL_VXE)
|
||||
message(STATUS "VX/VXE/VXE2 enabled")
|
||||
if (GGML_VXE OR GGML_INTERNAL_VXE2)
|
||||
message(STATUS "VXE2 enabled")
|
||||
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_VXE)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_VXE2)
|
||||
endif()
|
||||
|
||||
if (GGML_INTERNAL_NNPA)
|
||||
message(STATUS "NNPA enabled")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_NNPA)
|
||||
endif()
|
||||
|
||||
ggml_add_cpu_backend_features(${GGML_CPU_NAME} s390 ${ARCH_DEFINITIONS})
|
||||
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
|
||||
message(STATUS "Wasm detected")
|
||||
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
|
||||
|
||||
50
ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp
Normal file
50
ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#if defined(__s390x__)
|
||||
#include <sys/auxv.h>
|
||||
|
||||
// find hwcap bits in asm/elf.h
|
||||
#ifndef HWCAP_VXRS_EXT2
|
||||
#define HWCAP_VXRS_EXT2 (1 << 15)
|
||||
#endif
|
||||
|
||||
#ifndef HWCAP_NNPA
|
||||
#define HWCAP_NNPA (1 << 20)
|
||||
#endif
|
||||
|
||||
struct s390x_features {
|
||||
bool has_vxe2 = false;
|
||||
bool has_nnpa = false;
|
||||
|
||||
s390x_features() {
|
||||
uint32_t hwcap = getauxval(AT_HWCAP);
|
||||
// NOTE: use hwcap2 with DFLT for z17 and later
|
||||
// uint32_t hwcap2 = getauxval(AT_HWCAP2);
|
||||
|
||||
has_vxe2 = !!(hwcap & HWCAP_VXRS_EXT2);
|
||||
has_nnpa = !!(hwcap & HWCAP_NNPA);
|
||||
}
|
||||
};
|
||||
|
||||
static int ggml_backend_cpu_s390x_score() {
|
||||
int score = 1;
|
||||
s390x_features sf;
|
||||
|
||||
// IBM z15 / LinuxONE 3
|
||||
#ifdef GGML_USE_VXE2
|
||||
if (!sf.has_vxe2) { return 0; }
|
||||
score += 1 << 1;
|
||||
#endif
|
||||
|
||||
// IBM z16 / LinuxONE 4 and z17 / LinuxONE 5
|
||||
#ifdef GGML_USE_NNPA
|
||||
if (!sf.has_nnpa) { return 0; }
|
||||
score += 1 << 2;
|
||||
#endif
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_s390x_score)
|
||||
|
||||
#endif // __s390x__
|
||||
@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
|
||||
chunk_size = 64;
|
||||
}
|
||||
|
||||
#if defined(__aarch64__)
|
||||
// disable for ARM
|
||||
const bool disable_chunking = true;
|
||||
#else
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
#endif // defined(__aarch64__)
|
||||
|
||||
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
||||
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
||||
|
||||
@@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
|
||||
}
|
||||
|
||||
static void ggml_mrope_cache_init(
|
||||
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
||||
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
|
||||
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
@@ -5509,14 +5509,26 @@ static void ggml_mrope_cache_init(
|
||||
}
|
||||
|
||||
float theta = theta_t;
|
||||
if (sector >= sections[0] && sector < sec_w) {
|
||||
theta = theta_h;
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||
theta = theta_w;
|
||||
}
|
||||
else if (sector >= sec_w + sections[2]) {
|
||||
theta = theta_e;
|
||||
if (is_imrope) { // qwen3vl apply interleaved mrope
|
||||
if (sector % 3 == 1 && sector < 3 * sections[1]) {
|
||||
theta = theta_h;
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
|
||||
theta = theta_w;
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
|
||||
theta = theta_t;
|
||||
} else {
|
||||
theta = theta_e;
|
||||
}
|
||||
} else {
|
||||
if (sector >= sections[0] && sector < sec_w) {
|
||||
theta = theta_h;
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||
theta = theta_w;
|
||||
}
|
||||
else if (sector >= sec_w + sections[2]) {
|
||||
theta = theta_e;
|
||||
}
|
||||
}
|
||||
|
||||
rope_yarn(
|
||||
@@ -5589,6 +5601,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -5627,7 +5640,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
@@ -5775,6 +5788,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -5813,7 +5827,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
@@ -7519,8 +7533,8 @@ static void ggml_compute_forward_upscale_f32(
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
pixel_offset = 0.0f;
|
||||
sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
}
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
@@ -7909,10 +7923,10 @@ void ggml_compute_forward_argsort(
|
||||
|
||||
// ggml_compute_forward_flash_attn_ext
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
ggml_tensor * dst,
|
||||
int ir0, int ir1) {
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
@@ -7928,9 +7942,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t DK = nek0;
|
||||
const int64_t DV = nev0;
|
||||
const int64_t N = neq1;
|
||||
@@ -7964,16 +7975,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
// parallelize by q rows using ggml_vec_dot_f32
|
||||
|
||||
// total rows in q
|
||||
const int nr = neq1*neq2*neq3;
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
@@ -8000,6 +8001,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
||||
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
||||
|
||||
int ith = params->ith;
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
@@ -8147,6 +8150,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int64_t DK = nek0;
|
||||
const int64_t DV = nev0;
|
||||
const int64_t N = neq1;
|
||||
|
||||
GGML_ASSERT(ne0 == DV);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
// input tensor rows must be contiguous
|
||||
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
||||
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||
|
||||
GGML_ASSERT(neq0 == DK);
|
||||
GGML_ASSERT(nek0 == DK);
|
||||
GGML_ASSERT(nev0 == DV);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// parallelize by q rows using ggml_vec_dot_f32
|
||||
|
||||
// total rows in q
|
||||
const int64_t nr = neq1*neq2*neq3;
|
||||
|
||||
// rows per thread
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
ggml_threadpool_chunk_set(params->threadpool, nth);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// The number of elements in each chunk
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
const int64_t ir0 = dr * current_chunk;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_flash_attn_ext(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const void * src1_wdata = params->wdata;
|
||||
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (ne11 > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
}
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
}
|
||||
|
||||
void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
||||
}
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int64_t nr = ggml_nrows(op->src[0]);
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
}
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
ggml_threadpool_chunk_set(params->threadpool, nth);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const void * src1_wdata = params->wdata;
|
||||
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
||||
int64_t src0_start = (ith * ne01) / nth;
|
||||
int64_t src0_end = ((ith + 1) * ne01) / nth;
|
||||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_start >= src0_end) {
|
||||
return;
|
||||
}
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (ne11 > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
}
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
while (current_chunk < nchunk) {
|
||||
int64_t src0_start = (current_chunk * ne01) / nchunk;
|
||||
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
|
||||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_start >= src0_end) {
|
||||
break;
|
||||
}
|
||||
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ template<ggml_sort_order order>
|
||||
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
|
||||
// bitonic sort
|
||||
int col = threadIdx.x;
|
||||
int row = blockIdx.y;
|
||||
int row = blockIdx.x;
|
||||
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
const dim3 block_dims(ncols_pad, 1, 1);
|
||||
const dim3 block_nums(1, nrows, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
|
||||
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
||||
|
||||
@@ -224,6 +224,11 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define AMD_MFMA_AVAILABLE
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
|
||||
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#define VOLTA_MMA_AVAILABLE
|
||||
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
|
||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
#define TURING_MMA_AVAILABLE
|
||||
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
@@ -278,7 +283,10 @@ static bool amd_mfma_available(const int cc) {
|
||||
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
||||
}
|
||||
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static bool volta_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
static bool turing_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||
}
|
||||
@@ -625,8 +633,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
||||
// and a shift:
|
||||
//
|
||||
// n/d = (mulhi(n, mp) + n) >> L;
|
||||
static const uint3 init_fastdiv_values(uint32_t d) {
|
||||
GGML_ASSERT(d != 0);
|
||||
static const uint3 init_fastdiv_values(uint64_t d_64) {
|
||||
GGML_ASSERT(d_64 != 0);
|
||||
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
uint32_t d = (uint32_t)d_64;
|
||||
|
||||
// compute L = ceil(log2(d));
|
||||
uint32_t L = 0;
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ggml-cuda/mmq.cuh"
|
||||
#include "ggml-cuda/mmvf.cuh"
|
||||
#include "ggml-cuda/mmvq.cuh"
|
||||
#include "ggml-cuda/moe-expert-reduce.cuh"
|
||||
#include "ggml-cuda/norm.cuh"
|
||||
#include "ggml-cuda/opt-step-adamw.cuh"
|
||||
#include "ggml-cuda/opt-step-sgd.cuh"
|
||||
@@ -50,6 +51,7 @@
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/set.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml.h"
|
||||
@@ -2416,6 +2418,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_cuda_op_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
ggml_cuda_op_set(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
@@ -2494,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
ggml_cuda_op_xielu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
ggml_cuda_op_floor(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
ggml_cuda_op_ceil(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
ggml_cuda_op_round(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_cuda_op_trunc(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2974,7 +2991,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
|
||||
|
||||
@@ -2993,7 +3010,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
|
||||
@@ -3114,9 +3131,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||
|
||||
[[maybe_unused]] int prev_i = 0;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
if (nodes_fused > 0) {
|
||||
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
continue;
|
||||
}
|
||||
@@ -3154,6 +3182,31 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL) {
|
||||
int current_node = i + 1;
|
||||
int num_views = 0;
|
||||
int num_adds = 0;
|
||||
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
|
||||
num_views++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
|
||||
num_adds < num_views - 1) {
|
||||
num_adds++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
if (num_adds == num_views - 1 && num_views > 0) {
|
||||
ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
|
||||
if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
|
||||
ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
|
||||
i += num_views + num_adds;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
@@ -3728,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@@ -3842,6 +3899,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
const ggml_type t = op->type;
|
||||
return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
|
||||
t == op->src[0]->type &&
|
||||
t == op->src[1]->type;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
// On Volta each warp is doing 4 8x8 mma operations in parallel.
|
||||
// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile.
|
||||
// However, the i indices in this file are by default permuted to simplify the index calculations.
|
||||
// #define GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
|
||||
#if CUDART_VERSION >= 11080
|
||||
|
||||
@@ -73,6 +77,15 @@ namespace ggml_cuda_mma {
|
||||
static constexpr int ne = I * J / 64;
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 64 && J == 2) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 32 && J == 4) return true;
|
||||
if (I == 16 && J == 16) return true;
|
||||
if (I == 32 && J == 32) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
||||
return threadIdx.x % 16;
|
||||
@@ -85,7 +98,8 @@ namespace ggml_cuda_mma {
|
||||
} else if constexpr (I == 32 && J == 32) {
|
||||
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,22 +115,67 @@ namespace ggml_cuda_mma {
|
||||
} else if constexpr (I == 32 && J == 32) {
|
||||
return threadIdx.x % 32;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 32 && J == 8) {
|
||||
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (l & 2) | (threadIdx.x % 2);
|
||||
#else
|
||||
return (l & 2) | (threadIdx.x & ~2);
|
||||
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 32 && J == 8) {
|
||||
return (threadIdx.x & 2) | (l & (4 + 1));
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 4) return true;
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 16) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && (J == 4 || J == 8)) {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 8 + threadIdx.x / 4;
|
||||
return ((l / 2) * 8) | (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
||||
return (((l / 2) % 2) * 8) | (threadIdx.x / 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,13 +183,16 @@ namespace ggml_cuda_mma {
|
||||
if constexpr (I == 8 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 8 && J == 8) {
|
||||
return 4 * l + threadIdx.x % 4;
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return 2 * (threadIdx.x % 4) + l % 2;
|
||||
return ((threadIdx.x % 4) * 2) | (l % 2);
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
|
||||
return ((l / 4) * 8) | ((threadIdx.x % 4) * 2) | (l % 2);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
@@ -140,32 +202,83 @@ namespace ggml_cuda_mma {
|
||||
struct tile<I_, J_, half2> {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
static constexpr int ne = I == 8 && J == 8 ? I * J / (WARP_SIZE/4) : I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
return (((threadIdx.x % 16) / 4) * 8) | ((threadIdx.x / 16) * 4) | (threadIdx.x % 4);
|
||||
#else
|
||||
return threadIdx.x;
|
||||
#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr ((I == 8 || I == 32) && J == 8) {
|
||||
return l;
|
||||
} else {
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
half2 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 4) return true;
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
if (I == 16 && J == 16) return true;
|
||||
if (I == 32 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return l * 8 + threadIdx.x / 4;
|
||||
return (l * 8) | (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l % 2) * 8 + threadIdx.x / 4;
|
||||
return ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return ((l / 4) * 16) | ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return l * 4 + threadIdx.x % 4;
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 4 + threadIdx.x % 4;
|
||||
return ((l / 2) * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 32 && J == 8) {
|
||||
return ((l & 2) * 2) | (threadIdx.x % 4);
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
@@ -175,27 +288,36 @@ namespace ggml_cuda_mma {
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 8 && J == 8) return true;
|
||||
if (I == 16 && J == 4) return true;
|
||||
if (I == 16 && J == 8) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return threadIdx.x / 4;
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return l * 8 + threadIdx.x / 4;
|
||||
return (l * 8) | (threadIdx.x / 4);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l % 2) * 8 + threadIdx.x / 4;
|
||||
return ((l % 2) * 8) | (threadIdx.x / 4);
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 8 && J == 8) {
|
||||
return l * 4 + threadIdx.x % 4;
|
||||
return (l * 4) | (threadIdx.x % 4);
|
||||
} else if constexpr (I == 16 && J == 4) {
|
||||
return threadIdx.x % 4;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return (l / 2) * 4 + threadIdx.x % 4;
|
||||
return ((l / 2) * 4) | (threadIdx.x % 4);
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
NO_DEVICE_CODE;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -263,8 +385,12 @@ namespace ggml_cuda_mma {
|
||||
: "=r"(xi[0]), "=r"(xi[1])
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
GGML_UNUSED(t);
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
NO_DEVICE_CODE;
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
@@ -277,11 +403,35 @@ namespace ggml_cuda_mma {
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
|
||||
: "l"(xs));
|
||||
#else
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
GGML_UNUSED_VARS(t, xs0, stride);
|
||||
NO_DEVICE_CODE;
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // TURING_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<32, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if 1
|
||||
// TODO: more generic handling
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
|
||||
ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
|
||||
#else
|
||||
load_generic(t, xs0, stride);
|
||||
#endif // 1
|
||||
#else
|
||||
tile<16, 8, T> * t16 = (tile<16, 8, T> *) &t;
|
||||
load_ldmatrix(t16[0], xs0 + 0*stride, stride);
|
||||
load_ldmatrix(t16[1], xs0 + 16*stride, stride);
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix_trans(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
@@ -546,4 +696,43 @@ namespace ggml_cuda_mma {
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, int J, int K>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
|
||||
tile<16, J, T1> * D16 = (tile<16, J, T1> *) &D;
|
||||
tile<16, K, T2> * A16 = (tile<16, K, T2> *) &A;
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, 8, float> & D, const tile<32, 8, half2> & A, const tile<8, 8, half2> & B) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
const int * Axi = (const int *) A.x;
|
||||
const int * Bxi = (const int *) B.x;
|
||||
int * Dxi = (int *) D.x;
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[4]), "r"(Axi[5]), "r"(Bxi[4]), "r"(Bxi[5]));
|
||||
asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
|
||||
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
|
||||
#else
|
||||
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
|
||||
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
|
||||
mma(D16[0], A16[0], B);
|
||||
mma(D16[1], A16[1], B);
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return turing_mma_available(cc);
|
||||
return volta_mma_available(cc) || turing_mma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
return ampere_mma_available(cc);
|
||||
default:
|
||||
|
||||
@@ -28,9 +28,19 @@ static __global__ void mul_mat_f(
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
@@ -232,7 +242,6 @@ static __global__ void mul_mat_f(
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
@@ -245,9 +254,19 @@ static __global__ void mul_mat_f_ids(
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||
|
||||
if (!I_16_supported && !I_32_supported) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
|
||||
|
||||
typedef tile<I_preferred, 8, T> tile_A;
|
||||
typedef tile<8, 8, T> tile_B;
|
||||
typedef tile<I_preferred, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
@@ -533,7 +552,8 @@ void mul_mat_f_cuda(
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile<16, 8, T> tile_A_16;
|
||||
typedef tile<32, 8, T> tile_A_32;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
|
||||
GGML_ASSERT(ncols_x % 2 == 0);
|
||||
@@ -544,7 +564,8 @@ void mul_mat_f_cuda(
|
||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
|
||||
int64_t nwarps_best = 1;
|
||||
@@ -559,7 +580,7 @@ void mul_mat_f_cuda(
|
||||
}
|
||||
|
||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
|
||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
|
||||
@@ -343,6 +343,10 @@ static __global__ void mul_mat_vec_f(
|
||||
}
|
||||
|
||||
dst[tid*stride_col_dst + row] = value;
|
||||
|
||||
if constexpr (!has_fusion) {
|
||||
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
|
||||
@@ -190,12 +190,30 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||
|
||||
float x_biases[ncols_dst] = { 0.0f };
|
||||
float gate_biases[ncols_dst] = { 0.0f };
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
// 1. Hide latency by prefetching bias and gate here
|
||||
// 2. load only on threads that won't die after partial sum calculation
|
||||
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
|
||||
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
|
||||
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -283,12 +301,12 @@ static __global__ void mul_mat_vec_q(
|
||||
float result = tmp[j][threadIdx.x];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
result += x_bias[j*stride_col_dst + threadIdx.x];
|
||||
result += x_biases[j];
|
||||
}
|
||||
if (use_gate) {
|
||||
float gate_value = tmp_gate[j][threadIdx.x];
|
||||
if (use_gate_bias) {
|
||||
gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
|
||||
gate_value += gate_biases[j];
|
||||
}
|
||||
switch (active_glu) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
@@ -310,6 +328,10 @@ static __global__ void mul_mat_vec_q(
|
||||
dst[j*stride_col_dst + threadIdx.x] = result;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!has_fusion) {
|
||||
GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
|
||||
}
|
||||
}
|
||||
|
||||
static std::pair<dim3, dim3> calc_launch_params(
|
||||
|
||||
168
ggml/src/ggml-cuda/moe-expert-reduce.cu
Normal file
168
ggml/src/ggml-cuda/moe-expert-reduce.cu
Normal file
@@ -0,0 +1,168 @@
|
||||
#include "moe-expert-reduce.cuh"
|
||||
|
||||
// This kernel is a fusion of the expert weight reduce, common in MoE models
|
||||
|
||||
template <int n_expert_used_template>
|
||||
__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
|
||||
const float * __restrict__ weights,
|
||||
float * __restrict__ dst,
|
||||
const int n_expert_used,
|
||||
const int n_cols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
if (col >= n_cols) {
|
||||
return;
|
||||
}
|
||||
|
||||
experts += row * n_cols * n_expert_used;
|
||||
weights += row * n_expert_used;
|
||||
dst += row * n_cols;
|
||||
|
||||
float acc = 0.f;
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
for (int expert = 0; expert < n_expert_used; ++expert) {
|
||||
ggml_cuda_mad(acc, experts[col], weights[expert]);
|
||||
experts += n_cols;
|
||||
}
|
||||
dst[col] = acc;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_expert_used_template; ++i) {
|
||||
ggml_cuda_mad(acc, experts[col], weights[i]);
|
||||
experts += n_cols;
|
||||
}
|
||||
dst[col] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const float * experts,
|
||||
const float * weights,
|
||||
float * dst,
|
||||
const int n_expert_used,
|
||||
const int n_cols,
|
||||
const int n_rows) {
|
||||
const int block_size = 32;
|
||||
|
||||
const int n_blocks_x = n_rows;
|
||||
const int n_blocks_y = (n_cols + block_size - 1) / block_size;
|
||||
|
||||
dim3 block_dims(block_size);
|
||||
dim3 grid_dims(n_blocks_x, n_blocks_y);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
switch (n_expert_used) {
|
||||
case 1:
|
||||
moe_expert_reduce_cuda<1>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 2:
|
||||
moe_expert_reduce_cuda<2>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 4:
|
||||
moe_expert_reduce_cuda<4>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 6:
|
||||
moe_expert_reduce_cuda<6>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 8:
|
||||
moe_expert_reduce_cuda<8>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 16:
|
||||
moe_expert_reduce_cuda<16>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 32:
|
||||
moe_expert_reduce_cuda<32>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 64:
|
||||
moe_expert_reduce_cuda<64>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 128:
|
||||
moe_expert_reduce_cuda<128>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
default:
|
||||
moe_expert_reduce_cuda<0>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
|
||||
const ggml_tensor * mul = cgraph->nodes[start_index];
|
||||
|
||||
if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int current_node = start_index + 1;
|
||||
size_t current_offset = 0;
|
||||
|
||||
std::vector<const ggml_tensor *> view_nodes;
|
||||
//check if all are views of the expert in increasing order
|
||||
while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
|
||||
const ggml_tensor * node = cgraph->nodes[current_node];
|
||||
if (node->view_src != mul) {
|
||||
return false;
|
||||
}
|
||||
if (node->view_offs < current_offset) {
|
||||
return false;
|
||||
}
|
||||
current_offset = node->view_offs;
|
||||
current_node++;
|
||||
view_nodes.push_back(node);
|
||||
}
|
||||
|
||||
//check if all the adds are in increasing order
|
||||
const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0];
|
||||
int num_adds = 0;
|
||||
int num_views = view_nodes.size();
|
||||
while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
|
||||
const ggml_tensor * add_node = cgraph->nodes[current_node];
|
||||
|
||||
bool is_first_op_ok = num_views > num_adds ? add_node->src[0] == prev_add_src : false;
|
||||
bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false;
|
||||
|
||||
if (!is_first_op_ok || !is_second_op_ok) {
|
||||
return false;
|
||||
}
|
||||
prev_add_src = add_node;
|
||||
|
||||
num_adds++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
if (num_views != num_adds + 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * experts,
|
||||
const ggml_tensor * weights,
|
||||
ggml_tensor * dst) {
|
||||
const int n_rows = experts->ne[2];
|
||||
const int n_expert_used = experts->ne[1];
|
||||
const int n_cols = experts->ne[0];
|
||||
|
||||
GGML_ASSERT(experts->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(experts));
|
||||
GGML_ASSERT(ggml_is_contiguous(weights));
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * experts_d = (const float *) experts->data;
|
||||
const float * weights_d = (const float *) weights->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows);
|
||||
}
|
||||
11
ggml/src/ggml-cuda/moe-expert-reduce.cuh
Normal file
11
ggml/src/ggml-cuda/moe-expert-reduce.cuh
Normal file
@@ -0,0 +1,11 @@
|
||||
#include "common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * experts,
|
||||
const ggml_tensor * weights,
|
||||
ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);
|
||||
@@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_multi(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
||||
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
@@ -152,17 +152,29 @@ static __global__ void rope_multi(
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
@@ -276,7 +288,7 @@ template<bool forward, typename T>
|
||||
static void rope_multi_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
@@ -287,11 +299,11 @@ static void rope_multi_cuda(
|
||||
if (freq_factors == nullptr) {
|
||||
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
} else {
|
||||
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,6 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -406,11 +419,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_multi_cuda<forward>(
|
||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_multi_cuda<forward>(
|
||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
@@ -4,30 +4,53 @@
|
||||
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
||||
|
||||
// Generic quantized set_rows kernel template
|
||||
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
|
||||
static __global__ void k_set_rows_quant(
|
||||
const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
|
||||
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
||||
const idx_t * __restrict__ src1,
|
||||
block_type * __restrict__ dst,
|
||||
const int64_t ne_total,
|
||||
const int64_t ne10,
|
||||
const int64_t ne11,
|
||||
const int64_t ne12,
|
||||
const int64_t ne13,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t s10,
|
||||
const int64_t s11,
|
||||
const int64_t s12,
|
||||
const int64_t s1,
|
||||
const int64_t s2,
|
||||
const int64_t s3,
|
||||
const uint3 ne00,
|
||||
const uint3 ne01,
|
||||
const uint3 ne02,
|
||||
const uint3 ne11_fd,
|
||||
const uint3 ne12_fd) {
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i_base = i * qk;
|
||||
const int64_t i03 = i_base / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
uint32_t tmp = (uint32_t) i_base;
|
||||
uint2 div_mod;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
div_mod = fast_div_modulo(tmp, ne00);
|
||||
const int64_t i00 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne01);
|
||||
const int64_t i01 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne02);
|
||||
const int64_t i02 = div_mod.y;
|
||||
const int64_t i03 = div_mod.x;
|
||||
|
||||
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
|
||||
quantize_func(src_block, dst_block);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
|
||||
const int64_t s2 = nb2;
|
||||
const int64_t s3 = nb3;
|
||||
|
||||
if (ne_total > 0) {
|
||||
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
||||
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
||||
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
||||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
|
||||
ne01_fd, ne02_fd, ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename src_t, typename idx_t, typename dst_t>
|
||||
static __global__ void k_set_rows(
|
||||
const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
template <typename src_t, typename idx_t, typename dst_t>
|
||||
static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
||||
const idx_t * __restrict__ src1,
|
||||
dst_t * __restrict__ dst,
|
||||
const int64_t ne_total,
|
||||
const int64_t ne10,
|
||||
const int64_t ne11,
|
||||
const int64_t ne12,
|
||||
const int64_t ne13,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t s10,
|
||||
const int64_t s11,
|
||||
const int64_t s12,
|
||||
const int64_t s1,
|
||||
const int64_t s2,
|
||||
const int64_t s3,
|
||||
const uint3 ne00,
|
||||
const uint3 ne01,
|
||||
const uint3 ne02,
|
||||
const uint3 ne11_fd,
|
||||
const uint3 ne12_fd) {
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = i / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
uint32_t tmp = (uint32_t) i;
|
||||
uint2 div_mod;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
div_mod = fast_div_modulo(tmp, ne00);
|
||||
const int64_t i00 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne01);
|
||||
const int64_t i01 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne02);
|
||||
const int64_t i02 = div_mod.y;
|
||||
const int64_t i03 = div_mod.x;
|
||||
|
||||
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
|
||||
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
@@ -144,14 +196,16 @@ static void set_rows_cuda(
|
||||
const int64_t s2 = nb2/sizeof(dst_t);
|
||||
const int64_t s3 = nb3/sizeof(dst_t);
|
||||
|
||||
if (ne_total > 0) {
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
||||
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
||||
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
||||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
||||
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
||||
ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
39
ggml/src/ggml-cuda/set.cu
Normal file
39
ggml/src/ggml-cuda/set.cu
Normal file
@@ -0,0 +1,39 @@
|
||||
#include "set.cuh"
|
||||
#include "cpy.cuh"
|
||||
|
||||
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
|
||||
GGML_ASSERT(src1->type == src0->type);
|
||||
GGML_ASSERT(dst ->type == src0->type);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
const size_t nb1 = ((int32_t *) dst->op_params)[0];
|
||||
const size_t nb2 = ((int32_t *) dst->op_params)[1];
|
||||
const size_t nb3 = ((int32_t *) dst->op_params)[2];
|
||||
const size_t offset = ((int32_t *) dst->op_params)[3];
|
||||
const bool inplace= (bool) ((int32_t *) dst->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
}
|
||||
|
||||
ggml_tensor dst_view = *dst;
|
||||
dst_view.data = (void *)((char *)dst->data + offset);
|
||||
dst_view.ne[0] = src1->ne[0];
|
||||
dst_view.ne[1] = src1->ne[1];
|
||||
dst_view.ne[2] = src1->ne[2];
|
||||
dst_view.ne[3] = src1->ne[3];
|
||||
|
||||
dst_view.nb[0] = ggml_element_size(dst);
|
||||
dst_view.nb[1] = nb1;
|
||||
dst_view.nb[2] = nb2;
|
||||
dst_view.nb[3] = nb3;
|
||||
|
||||
ggml_cuda_cpy(ctx, src1, &dst_view);
|
||||
}
|
||||
7
ggml/src/ggml-cuda/set.cuh
Normal file
7
ggml/src/ggml-cuda/set.cuh
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_SET_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_floor(float x) {
|
||||
return floorf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_ceil(float x) {
|
||||
return ceilf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_round(float x) {
|
||||
return round(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_trunc(float x) {
|
||||
return trunc(x);
|
||||
}
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_elu>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_floor>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_ceil>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_round>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_trunc>(ctx, dst);
|
||||
}
|
||||
/* gated ops */
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
|
||||
@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -126,8 +126,8 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
|
||||
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
|
||||
@@ -211,12 +211,15 @@ static inline void hex_format_op_names(char * str, const struct ggml_tensor * t)
|
||||
// ** backend sessions
|
||||
|
||||
struct ggml_hexagon_session {
|
||||
ggml_hexagon_session(int dev_id) noexcept(false);
|
||||
ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
|
||||
~ggml_hexagon_session() noexcept(true);
|
||||
|
||||
void allocate(int dev_id) noexcept(false);
|
||||
void release() noexcept(true);
|
||||
|
||||
void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
|
||||
void flush();
|
||||
|
||||
ggml_backend_buffer_type buffer_type;
|
||||
ggml_backend_buffer_type repack_buffer_type;
|
||||
|
||||
@@ -237,15 +240,37 @@ struct ggml_hexagon_session {
|
||||
uint32_t prof_pkts;
|
||||
};
|
||||
|
||||
// Packet callback
|
||||
static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * context) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(context);
|
||||
void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
|
||||
// Bump pending flag (cleared in the session::flush once we get the responce)
|
||||
this->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(this->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
|
||||
}
|
||||
|
||||
if (sync) {
|
||||
flush();
|
||||
}
|
||||
}
|
||||
|
||||
// Flush HTP response queue i.e wait for all outstanding requests to complete
|
||||
void ggml_hexagon_session::flush() {
|
||||
dspqueue_t q = this->queue;
|
||||
|
||||
// Repeatedly read packets from the queue until it's empty. We don't
|
||||
// necessarily get a separate callback for each packet, and new packets
|
||||
// may arrive while we're processing the previous one.
|
||||
|
||||
while (1) {
|
||||
while (this->op_pending) {
|
||||
struct htp_general_rsp rsp;
|
||||
uint32_t rsp_size;
|
||||
uint32_t flags;
|
||||
@@ -253,22 +278,23 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex
|
||||
struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
uint32_t n_bufs;
|
||||
|
||||
// Read packet from queue
|
||||
int err = dspqueue_read_noblock(queue, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp);
|
||||
// Read response packet from queue
|
||||
int err = dspqueue_read(q, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp), // Max message length
|
||||
&rsp_size, // Message length
|
||||
(uint8_t *) &rsp,
|
||||
1000000); // Timeout
|
||||
|
||||
if (err == AEE_EWOULDBLOCK) {
|
||||
// Consumed all packets available for now
|
||||
return;
|
||||
if (err == AEE_EEXPIRED) {
|
||||
// TODO: might need to bail out if the HTP is stuck on something
|
||||
continue;
|
||||
}
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: dspqueue_read_noblock failed: 0x%08x\n", (unsigned) err);
|
||||
GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
|
||||
}
|
||||
|
||||
// Basic sanity checks
|
||||
@@ -281,21 +307,15 @@ static void htp_packet_callback(dspqueue_t queue, AEEResult error, void * contex
|
||||
// TODO: handle errors
|
||||
}
|
||||
|
||||
// FIXME: update profiling implementation
|
||||
sess->prof_usecs = rsp.prof_usecs;
|
||||
sess->prof_cycles = rsp.prof_cycles;
|
||||
sess->prof_pkts = rsp.prof_pkts;
|
||||
// TODO: update profiling implementation, currently only works for opt_opsync mode
|
||||
this->prof_usecs = rsp.prof_usecs;
|
||||
this->prof_cycles = rsp.prof_cycles;
|
||||
this->prof_pkts = rsp.prof_pkts;
|
||||
|
||||
sess->op_pending--; // atomic dec
|
||||
this->op_pending--; // atomic dec
|
||||
}
|
||||
}
|
||||
|
||||
// Error callback - simply terminates with an error. Used where we don't
|
||||
// expect errors.
|
||||
[[noreturn]] static void htp_error_callback(dspqueue_t queue, AEEResult error, void * context) {
|
||||
GGML_ABORT("ggml-hex: dspcall general error 0x%x: for queue %p\n", error, (void *) queue);
|
||||
}
|
||||
|
||||
// ** backend buffers
|
||||
|
||||
struct ggml_backend_hexagon_buffer_type_context {
|
||||
@@ -656,6 +676,15 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
|
||||
// Ensure we don't try to read more data than is available in the source buffer 'data'
|
||||
// or write more than the tensor can hold.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
// Calculate how many full rows and how many remaining bytes we need to process.
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
@@ -667,7 +696,8 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
|
||||
init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < nrows; i++) {
|
||||
// 1. Process all the full rows
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
@@ -676,6 +706,25 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
// 2. Process the final, potentially partial, row
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
// re-init the row because we are potentially copying a partial row
|
||||
init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
|
||||
|
||||
// Copy only the remaining bytes from the source.
|
||||
memcpy(buf_pd, src, n_rem_bytes);
|
||||
|
||||
// Repack the entire buffer
|
||||
repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
|
||||
|
||||
// Write only the corresponding remaining bytes to the destination tensor.
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
@@ -688,6 +737,14 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
|
||||
// Ensure we don't try to copy more data than the tensor actually contains.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
// Calculate how many full rows and how many remaining bytes we need to process.
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
@@ -699,7 +756,8 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
|
||||
|
||||
memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < nrows; i++) {
|
||||
// 1. Process all the full rows
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
@@ -708,6 +766,20 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
// 2. Process the final, potentially partial, row
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
// We still need to read and unpack the entire source row because quantization is block-based.
|
||||
memcpy(buf_pd, src, row_size);
|
||||
unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
|
||||
|
||||
// But we only copy the remaining number of bytes to the destination.
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
@@ -930,6 +1002,15 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
|
||||
// Ensure we don't try to read more data than is available in the source buffer 'data'
|
||||
// or write more than the tensor can hold.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
// Calculate how many full rows and how many remaining bytes we need to process.
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
@@ -941,7 +1022,8 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
|
||||
init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < nrows; i++) {
|
||||
// 1. Process all the full rows
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
@@ -950,6 +1032,25 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
// 2. Process the final, potentially partial, row
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
// re-init the row because we are potentially copying a partial row
|
||||
init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);
|
||||
|
||||
// Copy only the remaining bytes from the source.
|
||||
memcpy(buf_pd, src, n_rem_bytes);
|
||||
|
||||
// Repack the entire buffer
|
||||
repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
|
||||
|
||||
// Write only the corresponding remaining bytes to the destination tensor.
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
@@ -962,6 +1063,14 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
|
||||
// Ensure we don't try to copy more data than the tensor actually contains.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
// Calculate how many full rows and how many remaining bytes we need to process.
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
@@ -973,7 +1082,8 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
|
||||
|
||||
memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < nrows; i++) {
|
||||
// 1. Process all the full rows
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
@@ -982,6 +1092,20 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
// 2. Process the final, potentially partial, row
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
// We still need to read and unpack the entire source row because quantization is block-based.
|
||||
memcpy(buf_pd, src, row_size);
|
||||
unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
|
||||
|
||||
// But we only copy the remaining number of bytes to the destination.
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
@@ -1229,6 +1353,15 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
|
||||
// Ensure we don't try to read more data than is available in the source buffer 'data'
|
||||
// or write more than the tensor can hold.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
// Calculate how many full rows and how many remaining bytes we need to process.
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
@@ -1240,7 +1373,8 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
|
||||
|
||||
init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < nrows; i++) {
|
||||
// 1. Process all the full rows
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
@@ -1249,6 +1383,25 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
// 2. Process the final, potentially partial, row
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
// re-init the row because we are potentially copying a partial row
|
||||
init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);
|
||||
|
||||
// Copy only the remaining bytes from the source.
|
||||
memcpy(buf_pd, src, n_rem_bytes);
|
||||
|
||||
// Repack the entire buffer (partial data + zero padding).
|
||||
repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
|
||||
|
||||
// Write only the corresponding remaining bytes to the destination tensor.
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
@@ -1261,6 +1414,14 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
|
||||
// Ensure we don't try to copy more data than the tensor actually contains.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
// Calculate how many full rows and how many remaining bytes we need to process.
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
@@ -1272,7 +1433,8 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
|
||||
|
||||
memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < nrows; i++) {
|
||||
// 1. Process all the full rows
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
@@ -1281,6 +1443,20 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
// 2. Process the final, potentially partial, row
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
// We still need to read and unpack the entire source row because the format is block-based.
|
||||
memcpy(buf_pd, src, row_size);
|
||||
unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
|
||||
|
||||
// But we only copy the remaining number of bytes to the destination to respect the size limit.
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
@@ -1299,19 +1475,19 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4_0_q4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_Q8_0:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q8_0_q8x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_mxfp4_mxfp4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
@@ -1335,19 +1511,19 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4x4x2_q4_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_Q8_0:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q8x4x2_q8_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(size == ggml_nbytes(tensor));
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_mxfp4x4x2_mxfp4(data, tensor, size);
|
||||
break;
|
||||
|
||||
@@ -1564,7 +1740,8 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
|
||||
0, // Flags
|
||||
128 * 1024, // Request queue size (in bytes)
|
||||
64 * 1024, // Response queue size (in bytes)
|
||||
htp_packet_callback, htp_error_callback,
|
||||
nullptr, // Read packet callback (we handle reads explicitly)
|
||||
nullptr, // Error callback (we handle errors during reads)
|
||||
(void *) this, // Callback context
|
||||
&queue);
|
||||
if (err != 0) {
|
||||
@@ -1631,10 +1808,13 @@ void ggml_hexagon_session::release() noexcept(true) {
|
||||
}
|
||||
}
|
||||
|
||||
ggml_hexagon_session::ggml_hexagon_session(int dev_id) noexcept(false) {
|
||||
ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
|
||||
buffer_type.context = nullptr;
|
||||
repack_buffer_type.context = nullptr;
|
||||
|
||||
buffer_type.device = dev;
|
||||
repack_buffer_type.device = dev;
|
||||
|
||||
try {
|
||||
allocate(dev_id);
|
||||
|
||||
@@ -2202,7 +2382,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF;
|
||||
bufs[0].flags = 0;
|
||||
|
||||
// Second buffer Input Activations. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
@@ -2212,8 +2392,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer Output Activations. We'll handle DSP
|
||||
@@ -2224,7 +2403,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
bufs[2].ptr = dst->data;
|
||||
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[2].size = ggml_nbytes(dst);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 (normally weight) tensor
|
||||
auto sess = src0_buf->sess;
|
||||
@@ -2252,27 +2431,7 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags)
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
3, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 3, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2328,7 +2487,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = DSPQUEUE_BUFFER_FLAG_REF;
|
||||
bufs[0].flags = 0;
|
||||
|
||||
// Second buffer Input Activations. This is a buffer that the CPU
|
||||
// writes and the DSP reads, so we'll need to flush CPU caches and
|
||||
@@ -2338,8 +2497,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer expert IDs. This is a buffer that the CPU
|
||||
@@ -2350,8 +2508,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[2].ptr = src2->data;
|
||||
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[2].size = ggml_nbytes(src2);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Forth buffer Output Activations. We'll handle DSP
|
||||
@@ -2362,7 +2519,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
bufs[3].ptr = dst->data;
|
||||
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[3].size = ggml_nbytes(dst);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 (normally weight) tensor
|
||||
auto sess = src0_buf->sess;
|
||||
@@ -2391,27 +2548,7 @@ static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flag
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
4, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000 // Timeout
|
||||
);
|
||||
|
||||
if (err != 0) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 4, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2484,8 +2621,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
|
||||
// Second buffer = Second Operand of Binary op
|
||||
@@ -2497,8 +2633,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer = Output Activations. We'll handle DSP
|
||||
@@ -2509,7 +2644,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[2].ptr = dst->data;
|
||||
bufs[2].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[2].size = ggml_nbytes(dst);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
@@ -2537,26 +2672,7 @@ static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
3, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 3, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2621,8 +2737,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[0].ptr = src0->data;
|
||||
bufs[0].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[0].size = ggml_nbytes(src0);
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
|
||||
// Second buffer = experts bias
|
||||
@@ -2630,8 +2745,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[1].ptr = src1->data;
|
||||
bufs[1].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[1].size = ggml_nbytes(src1);
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Third buffer = activated experts
|
||||
@@ -2639,8 +2753,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[2].ptr = src2->data;
|
||||
bufs[2].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[2].size = ggml_nbytes(src2);
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
|
||||
// Forth buffer = output activations
|
||||
@@ -2648,7 +2761,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[3].ptr = dst->data;
|
||||
bufs[3].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[3].size = ggml_nbytes(dst);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
ggml_hexagon_session * sess = src0_buf->sess;
|
||||
@@ -2678,26 +2791,7 @@ static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
4, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, 4, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2795,8 +2889,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src0->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src0);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
++n_bufs;
|
||||
|
||||
@@ -2811,8 +2904,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src1->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src1);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
}
|
||||
@@ -2827,7 +2919,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = dst->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(dst);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
++n_bufs;
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
@@ -2860,26 +2952,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, n_bufs, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -2953,8 +3026,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src0->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src0);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP;
|
||||
++n_bufs;
|
||||
|
||||
@@ -2968,8 +3040,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src1->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src1);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
|
||||
@@ -2984,8 +3055,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = src2->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(src2);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | // Take a reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP
|
||||
++n_bufs;
|
||||
}
|
||||
@@ -3000,7 +3070,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
bufs[n_bufs].ptr = dst->data;
|
||||
bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base;
|
||||
bufs[n_bufs].size = ggml_nbytes(dst);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_REF | DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER);
|
||||
++n_bufs;
|
||||
|
||||
// Primary DSP session from the src0 tensor
|
||||
@@ -3033,26 +3103,7 @@ static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) {
|
||||
}
|
||||
|
||||
if ((opt_opmask & HTP_OPMASK_QUEUE)) {
|
||||
// Bump pending flag (cleared in the callback once we get the responce)
|
||||
sess->op_pending++; // atomic inc
|
||||
|
||||
int err = dspqueue_write(sess->queue,
|
||||
0, // flags - the framework will autoset this
|
||||
n_bufs, // number of buffers
|
||||
bufs, // buffer references
|
||||
sizeof(req),
|
||||
(const uint8_t *) &req, // Message
|
||||
1000000); // Timeout
|
||||
|
||||
if (0 != err) {
|
||||
GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", sess->name.c_str(), (unsigned) err);
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_opsync) {
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->enqueue(req, bufs, n_bufs, opt_opsync);
|
||||
}
|
||||
|
||||
t2 = ggml_time_us();
|
||||
@@ -3197,9 +3248,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
}
|
||||
|
||||
// Wait until all pending ops complete
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->flush();
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -3210,9 +3259,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
|
||||
HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
|
||||
|
||||
// Wait until all pending ops complete
|
||||
while (sess->op_pending) {
|
||||
;
|
||||
}
|
||||
sess->flush();
|
||||
}
|
||||
|
||||
struct node_info {
|
||||
@@ -3628,7 +3675,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
|
||||
devices[i].iface = ggml_backend_hexagon_device_i;
|
||||
devices[i].reg = reg;
|
||||
try {
|
||||
devices[i].context = new ggml_hexagon_session(i);
|
||||
devices[i].context = new ggml_hexagon_session(i, &devices[i]);
|
||||
} catch (std::exception const &exc) {
|
||||
GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
|
||||
devices[i].context = nullptr;
|
||||
|
||||
@@ -395,28 +395,14 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// Prep response buffer structs (needed for error responses, etc)
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -444,41 +430,21 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// Prep response buffer structs (needed for error responses, etc)
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[3].fd = bufs[3].fd;
|
||||
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||
rsp_bufs[3].size = bufs[3].size;
|
||||
rsp_bufs[3].offset = bufs[3].offset;
|
||||
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[3].fd;
|
||||
rsp_bufs[0].ptr = bufs[3].ptr;
|
||||
rsp_bufs[0].size = bufs[3].size;
|
||||
rsp_bufs[0].offset = bufs[3].offset;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -508,32 +474,18 @@ static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[2].fd;
|
||||
rsp_bufs[0].ptr = bufs[2].ptr;
|
||||
rsp_bufs[0].offset = bufs[2].offset;
|
||||
rsp_bufs[0].size = bufs[2].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -561,38 +513,18 @@ static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * r
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[3].fd = bufs[3].fd;
|
||||
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||
rsp_bufs[3].offset = bufs[3].offset;
|
||||
rsp_bufs[3].size = bufs[3].size;
|
||||
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[3].fd;
|
||||
rsp_bufs[0].ptr = bufs[3].ptr;
|
||||
rsp_bufs[0].offset = bufs[3].offset;
|
||||
rsp_bufs[0].size = bufs[3].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -622,26 +554,18 @@ static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * r
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
@@ -669,7 +593,7 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
@@ -677,33 +601,16 @@ static void proc_activations_req(struct htp_context * ctx,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
int write_idx = 1;
|
||||
if (3 == n_bufs) {
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
write_idx = 2;
|
||||
}
|
||||
int write_idx = (n_bufs == 3) ? 2 : 1;
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
rsp_bufs[0].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[0].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[0].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[0].size = bufs[write_idx].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
@@ -742,7 +649,7 @@ static void proc_activations_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_rope_req(struct htp_context * ctx,
|
||||
@@ -750,39 +657,16 @@ static void proc_rope_req(struct htp_context * ctx,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
int write_idx = 2;
|
||||
if (4 == n_bufs) {
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
write_idx++;
|
||||
}
|
||||
int write_idx = (n_bufs == 4) ? 3 : 2;
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
rsp_bufs[0].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[0].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[0].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[0].size = bufs[write_idx].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
@@ -819,7 +703,7 @@ static void proc_rope_req(struct htp_context * ctx,
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
@@ -29,10 +29,11 @@ if (CXX_IS_HIPCC)
|
||||
endif()
|
||||
else()
|
||||
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
||||
if(AMDGPU_TARGETS AND NOT GPU_TARGETS)
|
||||
set(GPU_TARGETS ${AMDGPU_TARGETS})
|
||||
endif()
|
||||
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
|
||||
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
|
||||
endif()
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
enable_language(HIP)
|
||||
|
||||
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <array>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
||||
}
|
||||
|
||||
// Return true if the edges in the graph match expectations.
|
||||
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
|
||||
int start_idx,
|
||||
std::initializer_list<std::array<int, 3>> edges) {
|
||||
for (const auto & edge : edges) {
|
||||
int dst_node = edge[0];
|
||||
int src_idx = edge[1];
|
||||
int src_node = edge[2];
|
||||
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// expose GGUF internals for test code
|
||||
GGML_API size_t gguf_type_size(enum gguf_type type);
|
||||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
||||
|
||||
@@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_ne02=%d", base, ne02);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_neox) {
|
||||
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
|
||||
} else if (is_mrope && !is_vision) {
|
||||
} else if ((is_mrope || is_imrope) && !is_vision) {
|
||||
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
|
||||
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
|
||||
} else if (is_vision) {
|
||||
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
|
||||
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -707,6 +707,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
if (op->src[0]->ne[0] != 32 &&
|
||||
op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 72 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
op->src[0]->ne[0] != 96 &&
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
|
||||
@@ -76,6 +76,7 @@
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
|
||||
#define FC_MUL_MV 600
|
||||
#define FC_MUL_MM 700
|
||||
#define FC_ROPE 800
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPTG 8
|
||||
|
||||
@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
@@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi(
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
if (sector < args.sect_0) {
|
||||
theta_base = (float) pos[i2];
|
||||
} else if (sector < sec_w01) {
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
} else if (sector < sec_w012) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
if (FC_rope_is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
|
||||
theta_base = (float) pos[i2 + args.ne02 * 0];
|
||||
} else { // e
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
} else {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
if (sector < args.sect_0) {
|
||||
theta_base = (float) pos[i2];
|
||||
} else if (sector < sec_w01) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
} else if (sector < sec_w012) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
}
|
||||
// end of mrope
|
||||
|
||||
@@ -5348,6 +5362,7 @@ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, hal
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
|
||||
@@ -5360,6 +5375,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
||||
@@ -5373,6 +5389,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
||||
@@ -5386,6 +5403,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
||||
@@ -5398,6 +5416,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
||||
@@ -5410,6 +5429,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
||||
@@ -5422,6 +5442,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
||||
@@ -5434,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
||||
|
||||
@@ -6156,8 +6156,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(ne0 - 1) / (ne00 - 1);
|
||||
sf1 = (float)(ne1 - 1) / (ne01 - 1);
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
|
||||
|
||||
@@ -79,8 +79,8 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (loadc_a + l < ne01) {
|
||||
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
|
||||
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (loadc_b + l < ne11) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
|
||||
@@ -79,7 +79,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (loadc_a + l < ne01) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
|
||||
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (loadc_b + l < ne11) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
|
||||
@@ -78,7 +78,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (loadc_a + l < ne01) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 8;
|
||||
int iqs = idx % 8;
|
||||
@@ -101,7 +101,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (loadc_b + l < ne11) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
|
||||
@@ -32,8 +32,10 @@
|
||||
#include "pad.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "roll.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "set_rows.hpp"
|
||||
#include "ssm_conv.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "wkv.hpp"
|
||||
|
||||
@@ -42,13 +42,16 @@
|
||||
#include "ggml-sycl/backend.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/element_wise.hpp"
|
||||
#include "ggml-sycl/norm.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/set_rows.hpp"
|
||||
#include "ggml-sycl/set.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml-sycl/repeat_back.hpp"
|
||||
#include "ggml-sycl/quantize.hpp"
|
||||
#include "ggml-sycl/ssm_conv.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
static bool g_sycl_loaded = false;
|
||||
@@ -2615,6 +2618,10 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_repeat_back(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
@@ -2631,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
ggml_sycl_op_rms_norm(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_rms_norm_back(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_l2_norm(ctx, dst);
|
||||
@@ -3679,6 +3691,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_REPEAT:
|
||||
ggml_sycl_repeat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
ggml_sycl_repeat_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_sycl_get_rows(ctx, dst);
|
||||
break;
|
||||
@@ -3818,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
ggml_sycl_leaky_relu(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
ggml_sycl_rms_norm_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_sycl_rms_norm(ctx, dst);
|
||||
break;
|
||||
@@ -3913,6 +3931,11 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_sycl_ssm_conv(ctx, dst);
|
||||
case GGML_OP_ROLL:
|
||||
ggml_sycl_roll(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_sycl_arange(ctx, dst);
|
||||
break;
|
||||
@@ -4516,6 +4539,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
}
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type == GGML_TYPE_F32;
|
||||
}
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NONE:
|
||||
@@ -4552,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
case GGML_OP_SCALE:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
@@ -4586,6 +4616,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
case GGML_OP_SSM_CONV:
|
||||
return op->type == GGML_TYPE_F32 &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ROLL:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARANGE:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
|
||||
@@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
float eps = 1e-5f;
|
||||
std::memcpy(&eps, dst->op_params, sizeof(float));
|
||||
if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;
|
||||
|
||||
const float * g_base = static_cast<const float *>(dst->src[0]->data); // dz
|
||||
const float * x_base = static_cast<const float *>(dst->src[1]->data); // x
|
||||
float * dx_base = static_cast< float *>(dst->data);
|
||||
|
||||
const int64_t D = dst->ne[0];
|
||||
const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
|
||||
const int64_t N = ggml_nrows(dst);
|
||||
if (D == 0 || N == 0) return;
|
||||
|
||||
const ggml_tensor *G = dst->src[0];
|
||||
const ggml_tensor *X = dst->src[1];
|
||||
const int ts = (int) ggml_type_size(X->type);
|
||||
GGML_ASSERT((size_t) X->nb[0] == (size_t) ts);
|
||||
GGML_ASSERT((size_t) G->nb[0] == (size_t) ts);
|
||||
GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);
|
||||
|
||||
const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
|
||||
const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
|
||||
const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;
|
||||
|
||||
dpct::queue_ptr q = ctx.stream();
|
||||
|
||||
// work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D
|
||||
const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
||||
auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
|
||||
int wg_cap = 256;
|
||||
if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
|
||||
int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));
|
||||
|
||||
// FP32 path: per-thread compensated accumulation + hierarchical reduction
|
||||
q->submit([&](sycl::handler &cgh) {
|
||||
const int nwarps_loc = std::max(1, WG / WARP_SIZE);
|
||||
// store one partial value per warp (xx and xg) for cross-warp reduction
|
||||
auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
|
||||
auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
|
||||
sycl::range<3>(1, 1, WG)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const int row = item_ct1.get_group(2);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
const int64_t i1 = row % n1;
|
||||
const int64_t i2 = (row / n1) % n2;
|
||||
const int64_t i3 = row / (n1 * n2);
|
||||
|
||||
const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
|
||||
const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
|
||||
float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;
|
||||
|
||||
// per-thread accumulation (compensated by default)
|
||||
float sum_xx = 0.f, sum_xg = 0.f;
|
||||
#ifndef GGML_SYCL_RMS_BACK_FAST
|
||||
float c_xx = 0.f, c_xg = 0.f;
|
||||
#endif
|
||||
for (int64_t col = tid; col < D; col += WG) {
|
||||
const float xv = x_row[col];
|
||||
const float gv = g_row[col];
|
||||
#ifdef GGML_SYCL_RMS_BACK_FAST
|
||||
sum_xx += xv * xv;
|
||||
sum_xg += xv * gv;
|
||||
#else
|
||||
float y1 = xv * xv - c_xx;
|
||||
float t1 = sum_xx + y1;
|
||||
c_xx = (t1 - sum_xx) - y1;
|
||||
sum_xx = t1;
|
||||
|
||||
float y2 = xv * gv - c_xg;
|
||||
float t2 = sum_xg + y2;
|
||||
c_xg = (t2 - sum_xg) - y2;
|
||||
sum_xg = t2;
|
||||
#endif
|
||||
}
|
||||
|
||||
// warp-level reduction
|
||||
sycl::float2 xx = sycl::float2(sum_xx,
|
||||
#ifndef GGML_SYCL_RMS_BACK_FAST
|
||||
c_xx
|
||||
#else
|
||||
0.f
|
||||
#endif
|
||||
);
|
||||
sycl::float2 xg = sycl::float2(sum_xg,
|
||||
#ifndef GGML_SYCL_RMS_BACK_FAST
|
||||
c_xg
|
||||
#else
|
||||
0.f
|
||||
#endif
|
||||
);
|
||||
xx = warp_reduce_sum(xx, item_ct1);
|
||||
xg = warp_reduce_sum(xg, item_ct1);
|
||||
|
||||
// cross-warp reduction using local memory (single barrier)
|
||||
const auto sub_group = item_ct1.get_sub_group();
|
||||
const auto sg_id = sub_group.get_group_linear_id();
|
||||
const auto wi_in_sg = sub_group.get_local_linear_id();
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
sycl::float2 xx_total = xx;
|
||||
sycl::float2 xg_total = xg;
|
||||
if (nwarps > 1) {
|
||||
if (wi_in_sg == 0) {
|
||||
l_xx[sg_id] = xx;
|
||||
l_xg[sg_id] = xg;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (sg_id == 0) {
|
||||
const unsigned wi_u = wi_in_sg;
|
||||
sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
|
||||
sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
|
||||
xx_total = warp_reduce_sum(xx_first, item_ct1);
|
||||
xg_total = warp_reduce_sum(xg_first, item_ct1);
|
||||
} else {
|
||||
// other subgroups keep their local totals; they'll be ignored
|
||||
xx_total = xx;
|
||||
xg_total = xg;
|
||||
}
|
||||
// ensure all threads see the first-subgroup result via broadcast below
|
||||
}
|
||||
|
||||
// compute inv_r and coeff once per row and broadcast to the whole work-group
|
||||
float inv_r = 0.f;
|
||||
float coeff = 0.f;
|
||||
if (tid == 0) {
|
||||
const float sum_xx_f = xx_total.x() + xx_total.y();
|
||||
const float sum_xdz_f = xg_total.x() + xg_total.y();
|
||||
const float mean_eps = sum_xx_f / (float) D + eps;
|
||||
const float sum_eps = sum_xx_f + eps * (float) D;
|
||||
inv_r = sycl::rsqrt(mean_eps);
|
||||
coeff = -sum_xdz_f / sum_eps;
|
||||
}
|
||||
inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
|
||||
coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);
|
||||
|
||||
for (int64_t col = tid; col < D; col += WG) {
|
||||
d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
56
ggml/src/ggml-sycl/repeat_back.cpp
Normal file
56
ggml/src/ggml-sycl/repeat_back.cpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#include "repeat_back.hpp"
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * src0_dd = (const float *) dst->src[0]->data;
|
||||
float * dst_dd = (float *) dst->data;
|
||||
|
||||
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
|
||||
const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2],
|
||||
ne03 = dst->src[0]->ne[3];
|
||||
|
||||
const int nr0 = (int) (ne00 / ne0);
|
||||
const int nr1 = (int) (ne01 / ne1);
|
||||
const int nr2 = (int) (ne02 / ne2);
|
||||
const int nr3 = (int) (ne03 / ne3);
|
||||
|
||||
const size_t total = ne0 * ne1 * ne2 * ne3;
|
||||
const int BLOCK_SIZE = 256;
|
||||
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
const size_t i = item_ct1.get_global_linear_id();
|
||||
if (i >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i0 = i % ne0;
|
||||
const int i1 = (i / ne0) % ne1;
|
||||
const int i2 = (i / (ne0 * ne1)) % ne2;
|
||||
const int i3 = i / (ne0 * ne1 * ne2);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
for (int j3 = 0; j3 < nr3; ++j3) {
|
||||
for (int j2 = 0; j2 < nr2; ++j2) {
|
||||
for (int j1 = 0; j1 < nr1; ++j1) {
|
||||
for (int j0 = 0; j0 < nr0; ++j0) {
|
||||
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
|
||||
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst_dd[i] = acc;
|
||||
});
|
||||
}
|
||||
8
ggml/src/ggml-sycl/repeat_back.hpp
Normal file
8
ggml/src/ggml-sycl/repeat_back.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
#ifndef GGML_SYCL_REPEAT_BACK_HPP
|
||||
#define GGML_SYCL_REPEAT_BACK_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_REPEAT_BACK_HPP
|
||||
122
ggml/src/ggml-sycl/roll.cpp
Normal file
122
ggml/src/ggml-sycl/roll.cpp
Normal file
@@ -0,0 +1,122 @@
|
||||
#include "roll.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using namespace sycl;
|
||||
|
||||
static inline int wrap_add(int i, int shift, int n) {
|
||||
|
||||
int s = i + shift;
|
||||
return (s >= n) ? (s - n) : s;
|
||||
}
|
||||
|
||||
static void kernel_roll_fused_i0_i1(
|
||||
queue &q,
|
||||
const float *src_d,
|
||||
float *dst_d,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int sh0, int sh1, int sh2, int sh3)
|
||||
{
|
||||
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
|
||||
|
||||
|
||||
const int stride1 = ne0;
|
||||
const int stride2 = ne0 * ne1;
|
||||
const int stride3 = ne0 * ne1 * ne2;
|
||||
|
||||
|
||||
const int shNe0 = (ne0 - sh0) % ne0;
|
||||
const int shNe1 = (ne1 - sh1) % ne1;
|
||||
const int shNe2 = (ne2 - sh2) % ne2;
|
||||
const int shNe3 = (ne3 - sh3) % ne3;
|
||||
|
||||
|
||||
const size_t g0 = (size_t) ne3;
|
||||
const size_t g1 = (size_t) ne2;
|
||||
const size_t g2 = (size_t) (ne1 * ne0);
|
||||
|
||||
const range<3> global{ g0, g1, g2 };
|
||||
|
||||
q.submit([&](handler &h) {
|
||||
h.parallel_for(global, [=](id<3> idx) {
|
||||
const int i3 = (int) idx[0];
|
||||
const int i2 = (int) idx[1];
|
||||
|
||||
const int fused = (int) idx[2];
|
||||
const int i1 = fused / ne0;
|
||||
const int i0 = fused - i1 * ne0; // fused % ne0
|
||||
|
||||
|
||||
const int idx_dst = i0
|
||||
+ i1 * stride1
|
||||
+ i2 * stride2
|
||||
+ i3 * stride3;
|
||||
|
||||
|
||||
const int s0 = wrap_add(i0, shNe0, ne0);
|
||||
const int s1 = wrap_add(i1, shNe1, ne1);
|
||||
const int s2 = wrap_add(i2, shNe2, ne2);
|
||||
const int s3 = wrap_add(i3, shNe3, ne3);
|
||||
|
||||
const int idx_src = s0
|
||||
+ s1 * stride1
|
||||
+ s2 * stride2
|
||||
+ s3 * stride3;
|
||||
|
||||
dst_d[idx_dst] = src_d[idx_src];
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const ggml_tensor *src = dst->src[0];
|
||||
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
|
||||
|
||||
const int ne0 = (int) dst->ne[0];
|
||||
const int ne1 = (int) dst->ne[1];
|
||||
const int ne2 = (int) dst->ne[2];
|
||||
const int ne3 = (int) dst->ne[3];
|
||||
|
||||
const int32_t *params = (const int32_t *) dst->op_params;
|
||||
int shift0 = params[0];
|
||||
int shift1 = params[1];
|
||||
int shift2 = params[2];
|
||||
int shift3 = params[3];
|
||||
|
||||
|
||||
if ((shift0 | shift1 | shift2 | shift3) == 0) {
|
||||
const size_t nb = ggml_nbytes(src);
|
||||
queue *q = ctx.stream();
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
|
||||
return;
|
||||
}
|
||||
|
||||
auto norm = [](int sh, int n) -> int {
|
||||
if (n <= 0) return 0;
|
||||
sh %= n;
|
||||
if (sh < 0) sh += n;
|
||||
return sh;
|
||||
};
|
||||
shift0 = norm(shift0, ne0);
|
||||
shift1 = norm(shift1, ne1);
|
||||
shift2 = norm(shift2, ne2);
|
||||
shift3 = norm(shift3, ne3);
|
||||
|
||||
try {
|
||||
queue *q = ctx.stream();
|
||||
|
||||
const float *src_d = (const float *) src->data;
|
||||
float *dst_d = (float *) dst->data;
|
||||
GGML_ASSERT(src_d && dst_d);
|
||||
|
||||
kernel_roll_fused_i0_i1(
|
||||
*q, src_d, dst_d,
|
||||
ne0, ne1, ne2, ne3,
|
||||
shift0, shift1, shift2, shift3
|
||||
);
|
||||
} catch (const std::exception &e) {
|
||||
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
20
ggml/src/ggml-sycl/roll.hpp
Normal file
20
ggml/src/ggml-sycl/roll.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_ROLL_HPP
|
||||
#define GGML_SYCL_ROLL_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_ROLL_HPP
|
||||
@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
|
||||
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
|
||||
// get index pos
|
||||
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
||||
if (i0 >= ne0) {
|
||||
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
|
||||
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
||||
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
||||
const float freq_scale, const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
||||
const mrope_sections sections, queue_ptr stream) {
|
||||
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
|
||||
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, main_stream);
|
||||
freq_factors, sections, is_imrope, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
main_stream);
|
||||
is_imrope, main_stream);
|
||||
} else {
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
|
||||
127
ggml/src/ggml-sycl/ssm_conv.cpp
Normal file
127
ggml/src/ggml-sycl/ssm_conv.cpp
Normal file
@@ -0,0 +1,127 @@
|
||||
#include "ssm_conv.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
using namespace sycl;
|
||||
|
||||
static void kernel_ssm_conv(
|
||||
queue &q,
|
||||
const float *src_data,
|
||||
const float *weights,
|
||||
float *dst_data,
|
||||
int d_conv,
|
||||
int d_inner,
|
||||
int n_t,
|
||||
int n_s,
|
||||
int ncs __attribute__((unused)),
|
||||
int src_stride_inner,
|
||||
int src_stride_seq,
|
||||
int dst_stride_token,
|
||||
int dst_stride_seq
|
||||
) {
|
||||
const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);
|
||||
const size_t work_group_size = 256;
|
||||
const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
|
||||
|
||||
const range<1> global_range(num_work_groups * work_group_size);
|
||||
const range<1> local_range(work_group_size);
|
||||
|
||||
q.submit([&](handler &h) {
|
||||
h.parallel_for(
|
||||
nd_range<1>(global_range, local_range),
|
||||
[=](nd_item<1> item) {
|
||||
const size_t idx = item.get_global_id(0);
|
||||
if (idx >= total_work) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = static_cast<int>(idx % d_inner);
|
||||
const int token = static_cast<int>((idx / d_inner) % n_t);
|
||||
const int seq = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));
|
||||
|
||||
const float *s = src_data
|
||||
+ static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
|
||||
+ static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)
|
||||
+ static_cast<size_t>(token);
|
||||
|
||||
const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int i0 = 0; i0 < d_conv; ++i0) {
|
||||
sumf += s[i0] * c[i0];
|
||||
}
|
||||
|
||||
const size_t dst_idx =
|
||||
static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
|
||||
static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +
|
||||
static_cast<size_t>(channel);
|
||||
|
||||
dst_data[dst_idx] = sumf;
|
||||
}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int d_conv = src1->ne[0];
|
||||
const int ncs = src0->ne[0];
|
||||
const int d_inner = src0->ne[1];
|
||||
const int n_t = dst->ne[1];
|
||||
const int n_s = dst->ne[2];
|
||||
|
||||
GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
|
||||
GGML_ASSERT(src0->ne[1] == d_inner);
|
||||
GGML_ASSERT(src1->ne[1] == d_inner);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == d_inner);
|
||||
GGML_ASSERT(dst->ne[1] == n_t);
|
||||
GGML_ASSERT(dst->ne[2] == n_s);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));
|
||||
|
||||
const int src_stride_inner = ncs;
|
||||
const int src_stride_seq = ncs * d_inner;
|
||||
const int dst_stride_token = d_inner;
|
||||
const int dst_stride_seq = d_inner * n_t;
|
||||
|
||||
try {
|
||||
queue *q = ctx.stream();
|
||||
|
||||
const float *src_data = static_cast<const float *>(src0->data);
|
||||
const float *weights = static_cast<const float *>(src1->data);
|
||||
float *dst_data = static_cast<float *>(dst->data);
|
||||
|
||||
GGML_ASSERT(src_data && weights && dst_data);
|
||||
|
||||
kernel_ssm_conv(
|
||||
*q,
|
||||
src_data,
|
||||
weights,
|
||||
dst_data,
|
||||
d_conv,
|
||||
d_inner,
|
||||
n_t,
|
||||
n_s,
|
||||
ncs,
|
||||
src_stride_inner,
|
||||
src_stride_seq,
|
||||
dst_stride_token,
|
||||
dst_stride_seq
|
||||
);
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
|
||||
throw;
|
||||
}
|
||||
}
|
||||
5
ggml/src/ggml-sycl/ssm_conv.hpp
Normal file
5
ggml/src/ggml-sycl/ssm_conv.hpp
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint ncols;
|
||||
uint nrows;
|
||||
uint order;
|
||||
} p;
|
||||
|
||||
@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
|
||||
dst_row[idx1] = tmp;
|
||||
}
|
||||
|
||||
void argsort(bool needs_bounds_check) {
|
||||
void argsort(bool needs_bounds_check, const uint row) {
|
||||
// bitonic sort
|
||||
const int col = int(gl_LocalInvocationID.x);
|
||||
const uint row = gl_WorkGroupID.y;
|
||||
|
||||
const uint row_offset = row * p.ncols;
|
||||
|
||||
@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
|
||||
|
||||
void main() {
|
||||
if (p.ncols == BLOCK_SIZE) {
|
||||
argsort(false);
|
||||
uint row = gl_WorkGroupID.y;
|
||||
while (row < p.nrows) {
|
||||
argsort(false, row);
|
||||
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||
}
|
||||
} else {
|
||||
argsort(true);
|
||||
uint row = gl_WorkGroupID.y;
|
||||
while (row < p.nrows) {
|
||||
argsort(true, row);
|
||||
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
#if defined(DATA_A_MXFP4)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
|
||||
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
|
||||
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
|
||||
const uint scales = data_a[a_offset + ib].scales[scalesi];
|
||||
const vec2 d = vec2(data_a[a_offset + ib].d);
|
||||
const vec2 dm = vec2(data_a[a_offset + ib].dm);
|
||||
|
||||
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
||||
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
|
||||
}
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(1, 0);
|
||||
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint is = 2 * n + b; // 0..7
|
||||
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
|
||||
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].d);
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
|
||||
const uint8_t hm = uint8_t(1 << (iqs / 16));
|
||||
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].d);
|
||||
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
|
||||
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
|
||||
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
|
||||
const f16vec2 d = bl.block.d;
|
||||
const f16vec2 dm = bl.block.dm;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
|
||||
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
|
||||
qs = unpack8(qs)[idx & 1];
|
||||
|
||||
const uint scales = bl.block.scales[scalesi];
|
||||
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
|
||||
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
uint32_t qs = bl.block.qs[iqs];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
|
||||
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -26,7 +26,7 @@ void main() {
|
||||
const float d = e8m0_to_fp32(data_a[ib].e);
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
|
||||
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
|
||||
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,8 +24,8 @@ void main() {
|
||||
const uint ql_idx = 32 * ip + il;
|
||||
const uint8_t qs = data_a[i].qs[32 * ip + il];
|
||||
|
||||
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
|
||||
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
|
||||
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
|
||||
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
|
||||
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
|
||||
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
|
||||
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
|
||||
|
||||
@@ -20,8 +20,8 @@ void main() {
|
||||
const uint is = 2 * il;
|
||||
const uint n = 4;
|
||||
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
|
||||
|
||||
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
|
||||
const uint qs_idx = 32*il + n * ir;
|
||||
|
||||
@@ -19,8 +19,8 @@ void main() {
|
||||
const uint ir = tid % 16;
|
||||
const uint is = 2 * il;
|
||||
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
|
||||
|
||||
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
|
||||
const uint qs_idx = 32*il + 2 * ir;
|
||||
|
||||
@@ -28,8 +28,11 @@ layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
layout (binding = 4) readonly buffer IDS {int data_ids[];};
|
||||
#endif
|
||||
|
||||
#include "dequant_funcs.glsl"
|
||||
@@ -45,6 +48,8 @@ layout (push_constant) uniform parameter
|
||||
uint batch_stride_b;
|
||||
uint batch_stride_d;
|
||||
|
||||
uint enable_bias;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
uint ne11;
|
||||
@@ -56,6 +61,10 @@ layout (push_constant) uniform parameter
|
||||
#endif
|
||||
} p;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint expert_id;
|
||||
#endif
|
||||
|
||||
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
@@ -75,7 +84,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
batch_idx_a = i03 * p.ne02 + i02;
|
||||
}
|
||||
#else
|
||||
const uint expert_id = data_ids[expert_idx];
|
||||
expert_id = data_ids[expert_idx];
|
||||
#endif
|
||||
|
||||
a_offset =
|
||||
@@ -113,6 +122,13 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
}
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
}
|
||||
}
|
||||
@@ -148,6 +164,13 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
temp[j][n] += tmpsh[j][n][s];
|
||||
}
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
}
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
}
|
||||
}
|
||||
@@ -173,6 +196,13 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
}
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ncols_x;
|
||||
@@ -29,6 +31,7 @@ layout (push_constant) uniform parameter
|
||||
uint nb03;
|
||||
uint nb13;
|
||||
uint nb23;
|
||||
uint enable_bias;
|
||||
} p;
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
@@ -117,6 +120,9 @@ void main() {
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
if (p.enable_bias != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_bias[idst]);
|
||||
}
|
||||
dst[idst] = tmp[0];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
|
||||
layout(constant_id = 0) const int BLOCK_SIZE = 32;
|
||||
// gqa_ratio is in the range [1,8]
|
||||
layout(constant_id = 1) const uint gqa_ratio = 1;
|
||||
@@ -29,6 +31,7 @@ layout (push_constant) uniform parameter
|
||||
uint nchannels_y;
|
||||
uint b_offset;
|
||||
uint d_offset;
|
||||
uint enable_bias;
|
||||
} p;
|
||||
|
||||
#if !USE_SUBGROUP_ADD
|
||||
@@ -148,6 +151,9 @@ void main() {
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
// dst is not transposed and not permuted
|
||||
const uint idst = (channel + c)*nrows_dst + row_dst;
|
||||
if (p.enable_bias != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_bias[idst]);
|
||||
}
|
||||
dst[idst] = temp[c];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||
|
||||
vec2 d = vec2(data_a[ib0 + i].d);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
|
||||
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
|
||||
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
|
||||
}
|
||||
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
|
||||
temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
vec2 d = vec2(data_a[ib0 + i].d);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
|
||||
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
|
||||
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
|
||||
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
vec2 d = vec2(data_a[ib0 + i].d);
|
||||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
|
||||
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
|
||||
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
|
||||
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
#include "mul_mm_id_funcs.glsl"
|
||||
#include "mul_mm_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
|
||||
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
|
||||
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
|
||||
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
|
||||
const uint scalesi = iqs / 8; // 0..15
|
||||
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
||||
|
||||
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
|
||||
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
|
||||
const uint scales = data_a[ib].scales[scalesi];
|
||||
const vec2 d = vec2(data_a[ib].d);
|
||||
const vec2 dm = vec2(data_a[ib].dm);
|
||||
|
||||
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
||||
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
|
||||
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint is = 2 * n + b; // 0..7
|
||||
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].d);
|
||||
const vec2 loadd = vec2(data_a[ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const uint8_t hm = uint8_t(1 << (iqs / 16));
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].d);
|
||||
const vec2 loadd = vec2(data_a[ib].dm);
|
||||
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = (idx & 0x07) * 2;
|
||||
|
||||
const float d = e8m0_to_fp32(data_a[ib].e);
|
||||
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
|
||||
70
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
Normal file
70
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
Normal file
@@ -0,0 +1,70 @@
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
@@ -10,10 +10,9 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -24,7 +23,10 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
@@ -76,40 +78,31 @@ layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
#define BK 32
|
||||
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK / 4 + 4)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK / 4 + 1)
|
||||
#endif
|
||||
#define MMQ_SHMEM
|
||||
|
||||
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
|
||||
|
||||
#ifndef COOPMAT
|
||||
#if QUANT_AUXF == 1
|
||||
shared FLOAT_TYPE buf_a_dm[BM];
|
||||
#else
|
||||
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
|
||||
#endif
|
||||
#endif
|
||||
|
||||
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
|
||||
#ifndef COOPMAT
|
||||
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
||||
#endif
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R)
|
||||
#define LOAD_VEC_B 16
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
#define BK_STEP 1
|
||||
#else
|
||||
#ifndef BK_STEP
|
||||
#define BK_STEP 4
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Shared memory cache
|
||||
shared block_a_cache buf_a[BM * BK_STEP];
|
||||
shared block_b_cache buf_b[BN * BK_STEP];
|
||||
// Register cache
|
||||
block_a_cache cache_a[WMITER * TM];
|
||||
block_b_cache cache_b;
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
|
||||
#define LOAD_VEC_B 16
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
#include "mul_mm_id_funcs.glsl"
|
||||
#include "mul_mmq_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
@@ -139,26 +132,12 @@ void main() {
|
||||
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const uint WSUBM = WM / WMITER;
|
||||
const uint WSUBN = WN / WNITER;
|
||||
|
||||
#ifdef COOPMAT
|
||||
const uint warp_i = gl_SubgroupID;
|
||||
|
||||
const uint tiw = gl_SubgroupInvocationID;
|
||||
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
|
||||
const uint storestride = WARP / TM;
|
||||
const uint store_r = tiw % TM;
|
||||
const uint store_c = tiw / TM;
|
||||
#else
|
||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||
|
||||
const uint tiw = gl_LocalInvocationID.x % WARP;
|
||||
|
||||
const uint tiwr = tiw % (WSUBM / TM);
|
||||
const uint tiwc = tiw / (WSUBM / TM);
|
||||
#endif
|
||||
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
@@ -172,17 +151,27 @@ void main() {
|
||||
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
if (_ne1 >= ic * BN) {
|
||||
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
@@ -209,159 +198,70 @@ void main() {
|
||||
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
int32_t cache_a_qs[WMITER * TM * BK / 4];
|
||||
|
||||
int32_t cache_b_qs[TN * BK / 4];
|
||||
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE cache_a_dm[WMITER * TM];
|
||||
#else
|
||||
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
|
||||
#endif
|
||||
|
||||
FLOAT_TYPE_VEC2 cache_b_ds[TN];
|
||||
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
|
||||
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
|
||||
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
|
||||
const uint iqs = loadr_a;
|
||||
const uint buf_ib = loadc_a + l;
|
||||
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
|
||||
const uint iqs = loadr_a;
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
buf_a_dm[buf_ib] = get_d(ib);
|
||||
#else
|
||||
buf_a_dm[buf_ib] = get_dm(ib);
|
||||
#endif
|
||||
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
|
||||
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
|
||||
}
|
||||
#if QUANT_R == 1
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
|
||||
#else
|
||||
const i32vec2 vals = repack(ib, iqs);
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x7;
|
||||
#else
|
||||
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
const uint iqs = loadr_b;
|
||||
#endif
|
||||
|
||||
const uint buf_ib = loadc_b + l;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[buf_ib];
|
||||
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
|
||||
#else
|
||||
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
|
||||
#endif
|
||||
const uint iqs = loadr_b;
|
||||
|
||||
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
|
||||
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
|
||||
}
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a_ib += 1;
|
||||
pos_b_ib += 1;
|
||||
pos_a_ib += BK_STEP;
|
||||
pos_b_ib += BK_STEP;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint ib_a = warp_r * WM + cm_row * TM;
|
||||
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
|
||||
// Load from shared into cache
|
||||
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
// TODO: only cache values that are actually needed
|
||||
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
|
||||
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint ib_b = warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
// TODO: only cache values that are actually needed
|
||||
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
|
||||
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
|
||||
}
|
||||
|
||||
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
|
||||
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
|
||||
}
|
||||
|
||||
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
|
||||
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
|
||||
cache_b_ds[cc] = buf_b_ds[ib];
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint cache_a_idx = wsir * TM + cr;
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
|
||||
cache_b_qs[cc * (BK / 4) + idx_k]);
|
||||
}
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint reg_ib = wsir * TM + cr;
|
||||
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
|
||||
|
||||
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
|
||||
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
|
||||
block_b_to_registers(ib);
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint cache_a_idx = wsir * TM + cr;
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
|
||||
sums[sums_idx] += mmq_dot_product(cache_a_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
@@ -373,54 +273,6 @@ void main() {
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#ifdef MUL_MAT_ID
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
||||
|
||||
if (is_aligned && is_in_bounds) {
|
||||
// Full coopMat is within bounds and stride_d is aligned with 16B
|
||||
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
||||
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
} else if (is_in_bounds) {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
#else
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
@@ -431,19 +283,21 @@ void main() {
|
||||
const uint row_i = dc_warp + cc;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
|
||||
#ifdef MUL_MAT_ID
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
#else
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // COOPMAT
|
||||
}
|
||||
|
||||
@@ -6,41 +6,89 @@
|
||||
|
||||
// Each iqs value maps to a 32-bit integer
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
// 2-byte loads for Q4_0 blocks (18 bytes)
|
||||
// 4-byte loads for Q4_1 blocks (20 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
|
||||
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]);
|
||||
#ifdef DATA_A_Q4_0
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
#else // DATA_A_Q4_1
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef DATA_A_Q4_0
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
#else // DATA_A_Q4_1
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
}
|
||||
#else // DATA_A_Q4_1
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
const uint32_t vui = cache_a[ib_a].qs[iqs];
|
||||
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
|
||||
const int32_t qs_b0 = cache_b.qs[iqs];
|
||||
const int32_t qs_b1 = cache_b.qs[iqs + 4];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
|
||||
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
// 2-byte loads for Q5_0 blocks (22 bytes)
|
||||
// 4-byte loads for Q5_1 blocks (24 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
|
||||
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]);
|
||||
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
|
||||
#ifdef DATA_A_Q5_0
|
||||
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
|
||||
#else // DATA_A_Q5_1
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
#endif
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
#ifdef DATA_A_Q5_0
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_1)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
#else // DATA_A_Q5_1
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
|
||||
}
|
||||
#else // DATA_A_Q5_1
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
cache_a[reg_ib].qh = buf_a[buf_ib].qh;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
const uint32_t vui = cache_a[ib_a].qs[iqs];
|
||||
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
|
||||
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
const int32_t qs_b0 = cache_b.qs[iqs];
|
||||
const int32_t qs_b1 = cache_b.qs[iqs + 4];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
|
||||
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
// 2-byte loads for Q8_0 blocks (34 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
|
||||
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]));
|
||||
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
|
||||
data_a_packed16[ib].qs[iqs * 2 + 1]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
const int32_t qs_b = cache_b.qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, qs_b);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
// 1-byte loads for mxfp4 blocks (17 bytes)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
data_a[ib].qs[iqs * 4 + 2],
|
||||
data_a[ib].qs[iqs * 4 + 3]));
|
||||
|
||||
return i32vec2( quants & 0x0F0F0F0F,
|
||||
(quants >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * dsb.x * float(q_sum));
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
|
||||
data_a[ib].qs[iqs * 4 + 1],
|
||||
data_a[ib].qs[iqs * 4 + 2],
|
||||
data_a[ib].qs[iqs * 4 + 3]));
|
||||
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
|
||||
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
|
||||
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].d = buf_a[buf_ib].d;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
|
||||
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
|
||||
#if defined(DATA_A_Q2_K)
|
||||
// 4-byte loads for Q2_K blocks (84 bytes)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
|
||||
}
|
||||
|
||||
uint8_t get_scale(uint ib, uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
return data_a[ib_k].scales[iqs_k / 4];
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
// Repack 4x4 quants into one int
|
||||
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
|
||||
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
|
||||
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
|
||||
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
|
||||
|
||||
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t sum_d = 0;
|
||||
int32_t sum_m = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
|
||||
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
|
||||
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
|
||||
|
||||
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
|
||||
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q3_K)
|
||||
// 2-byte loads for Q3_K blocks (110 bytes)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint hm_idx = iqs * QUANT_R_MMQ;
|
||||
const uint iqs_k = (ib % 8) * 8 + hm_idx;
|
||||
|
||||
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
|
||||
const uint hm_shift = iqs_k / 8;
|
||||
|
||||
// Repack 2x4 quants into one int
|
||||
// Add the 3rd bit instead of subtracting it to allow packing the quants
|
||||
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
|
||||
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
float result = 0.0;
|
||||
int32_t q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
// Subtract 4 from the quants to correct the 3rd bit offset
|
||||
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
|
||||
q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
||||
|
||||
return ACC_TYPE(cache_b.ds.x * result);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
|
||||
}
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
|
||||
|
||||
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
|
||||
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
|
||||
|
||||
// Repack 2x4 quants into one int
|
||||
#if defined(DATA_A_Q4_K)
|
||||
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
|
||||
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
|
||||
|
||||
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
|
||||
#else // defined(DATA_A_Q5_K)
|
||||
const uint qh_idx = iqs * QUANT_R_MMQ;
|
||||
const uint qh_shift = iqs_k / 8;
|
||||
|
||||
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
|
||||
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
|
||||
#endif
|
||||
|
||||
|
||||
if (iqs == 0) {
|
||||
// Scale index
|
||||
const uint is = iqs_k / 8;
|
||||
u8vec2 scale_dm;
|
||||
if (is < 4) {
|
||||
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
|
||||
} else {
|
||||
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
int32_t q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
#if defined(DATA_A_Q4_K)
|
||||
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
|
||||
#else // defined(DATA_A_Q5_K)
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
#endif
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
|
||||
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
||||
}
|
||||
|
||||
void block_b_to_registers(const uint ib) {
|
||||
cache_b.ds = buf_b[ib].ds;
|
||||
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
|
||||
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q6_K)
|
||||
// 2-byte loads for Q6_K blocks (210 bytes)
|
||||
#ifdef MMQ_SHMEM
|
||||
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint ib_k = ib / 8;
|
||||
const uint iqs_k = (ib % 8) * 8 + iqs;
|
||||
|
||||
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
|
||||
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
|
||||
|
||||
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
|
||||
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
|
||||
}
|
||||
}
|
||||
|
||||
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
|
||||
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
|
||||
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE mmq_dot_product(const uint ib_a) {
|
||||
float result = 0.0;
|
||||
int32_t q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
|
||||
q_sum = 0;
|
||||
|
||||
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
|
||||
const int32_t qs_a = cache_a[ib_a].qs[iqs];
|
||||
|
||||
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
|
||||
}
|
||||
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
|
||||
|
||||
return ACC_TYPE(cache_b.ds.x * result);
|
||||
}
|
||||
#endif // MMQ_SHMEM
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
const uint ib_k = ib / 8;
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
78
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
Normal file
78
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
Normal file
@@ -0,0 +1,78 @@
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
#define QUANT_R_MMQ 1
|
||||
// AMD likes 4, Intel likes 1 and Nvidia likes 2
|
||||
// #define BK_STEP 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[32/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE d;
|
||||
};
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
#define QUANT_R_MMQ 4
|
||||
struct block_a_cache {
|
||||
uint32_t qs[2];
|
||||
u8vec2 scales;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct block_b_cache
|
||||
{
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 ds;
|
||||
};
|
||||
@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
|
||||
uint rms_partials;
|
||||
} p;
|
||||
|
||||
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
||||
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
|
||||
|
||||
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
|
||||
// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
||||
layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0;
|
||||
layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1;
|
||||
layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2;
|
||||
layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3;
|
||||
layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4;
|
||||
layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5;
|
||||
layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6;
|
||||
layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7;
|
||||
layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8;
|
||||
layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9;
|
||||
layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
|
||||
layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
|
||||
layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0;
|
||||
layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1;
|
||||
layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2;
|
||||
layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3;
|
||||
layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4;
|
||||
layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5;
|
||||
layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6;
|
||||
layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7;
|
||||
layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8;
|
||||
layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9;
|
||||
layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
|
||||
layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
|
||||
layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0;
|
||||
layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1;
|
||||
layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2;
|
||||
layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3;
|
||||
layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4;
|
||||
layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5;
|
||||
layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6;
|
||||
layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7;
|
||||
layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8;
|
||||
layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9;
|
||||
layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
|
||||
layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
|
||||
|
||||
layout(constant_id = 0) const uint num_srcs = 2;
|
||||
|
||||
FLOAT_TYPE load_a(uint b, uint i) {
|
||||
switch (b) {
|
||||
case 0: return FLOAT_TYPE(a0.data_a[i]);
|
||||
case 1: return FLOAT_TYPE(a1.data_a[i]);
|
||||
case 2: return FLOAT_TYPE(a2.data_a[i]);
|
||||
case 3: return FLOAT_TYPE(a3.data_a[i]);
|
||||
case 4: return FLOAT_TYPE(a4.data_a[i]);
|
||||
case 5: return FLOAT_TYPE(a5.data_a[i]);
|
||||
case 6: return FLOAT_TYPE(a6.data_a[i]);
|
||||
case 7: return FLOAT_TYPE(a7.data_a[i]);
|
||||
case 8: return FLOAT_TYPE(a8.data_a[i]);
|
||||
case 9: return FLOAT_TYPE(a9.data_a[i]);
|
||||
case 10: return FLOAT_TYPE(a10.data_a[i]);
|
||||
case 11: return FLOAT_TYPE(a11.data_a[i]);
|
||||
default: return FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
|
||||
void store_d(uint b, uint i, FLOAT_TYPE v) {
|
||||
switch (b) {
|
||||
case 0: d0.data_d[i] = D_TYPE(v); break;
|
||||
case 1: d1.data_d[i] = D_TYPE(v); break;
|
||||
case 2: d2.data_d[i] = D_TYPE(v); break;
|
||||
case 3: d3.data_d[i] = D_TYPE(v); break;
|
||||
case 4: d4.data_d[i] = D_TYPE(v); break;
|
||||
case 5: d5.data_d[i] = D_TYPE(v); break;
|
||||
case 6: d6.data_d[i] = D_TYPE(v); break;
|
||||
case 7: d7.data_d[i] = D_TYPE(v); break;
|
||||
case 8: d8.data_d[i] = D_TYPE(v); break;
|
||||
case 9: d9.data_d[i] = D_TYPE(v); break;
|
||||
case 10: d10.data_d[i] = D_TYPE(v); break;
|
||||
case 11: d11.data_d[i] = D_TYPE(v); break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
void store_partial(uint b, uint i, float v) {
|
||||
switch (b) {
|
||||
case 0: partials0.partial_sums[i] = v; break;
|
||||
case 1: partials1.partial_sums[i] = v; break;
|
||||
case 2: partials2.partial_sums[i] = v; break;
|
||||
case 3: partials3.partial_sums[i] = v; break;
|
||||
case 4: partials4.partial_sums[i] = v; break;
|
||||
case 5: partials5.partial_sums[i] = v; break;
|
||||
case 6: partials6.partial_sums[i] = v; break;
|
||||
case 7: partials7.partial_sums[i] = v; break;
|
||||
case 8: partials8.partial_sums[i] = v; break;
|
||||
case 9: partials9.partial_sums[i] = v; break;
|
||||
case 10: partials10.partial_sums[i] = v; break;
|
||||
case 11: partials11.partial_sums[i] = v; break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
|
||||
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
|
||||
}
|
||||
@@ -78,10 +162,10 @@ void main() {
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0);
|
||||
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
|
||||
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
|
||||
sum += load_a(s, src_idx(s, i00, i01, i02, i03));
|
||||
}
|
||||
sum_sq += sum*sum;
|
||||
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
@@ -104,7 +188,7 @@ void main() {
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -10,6 +10,7 @@ layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_pos[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_ff[];};
|
||||
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
|
||||
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint ncols;
|
||||
@@ -26,7 +27,9 @@ layout (push_constant) uniform parameter {
|
||||
uint s1;
|
||||
uint s2;
|
||||
int sections[4];
|
||||
uint is_imrope;
|
||||
uint is_back;
|
||||
uint set_rows_stride;
|
||||
} p;
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
|
||||
@@ -32,17 +32,29 @@ void main() {
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
if (p.is_imrope != 0) {
|
||||
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -16,12 +16,19 @@ void main() {
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = row_x*ne0 + i0/2;
|
||||
idst += data_i[channel_x].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
|
||||
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -16,12 +16,19 @@ void main() {
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0;
|
||||
uint idst = row_dst*ne0 + i0;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = row_x*ne0 + i0;
|
||||
idst += data_i[channel_x].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + 0] = data_a[ix + 0];
|
||||
data_d[idst + 1] = data_a[ix + 1];
|
||||
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
|
||||
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
uint n_expert_used;
|
||||
float clamp_min;
|
||||
float clamp_max;
|
||||
};
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint n_experts = 512;
|
||||
layout(constant_id = 2) const bool with_norm = true;
|
||||
layout(constant_id = 3) const bool late_softmax = false;
|
||||
|
||||
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||
|
||||
@@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
|
||||
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
||||
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
|
||||
|
||||
const float INFINITY = 1.0 / 0.0;
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
|
||||
float max_val = -INFINITY;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const uint idx = lane + i * WARP_SIZE;
|
||||
const bool is_active = !use_limit || (idx < limit);
|
||||
if (is_active) {
|
||||
max_val = max(max_val, vals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
max_val = subgroupMax(max_val);
|
||||
|
||||
float sum = 0.f;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const uint idx = lane + i * WARP_SIZE;
|
||||
const bool is_active = !use_limit || (idx < limit);
|
||||
if (is_active) {
|
||||
const float val = exp(vals[i] - max_val);
|
||||
vals[i] = val;
|
||||
sum += val;
|
||||
} else {
|
||||
vals[i] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
sum = subgroupAdd(sum);
|
||||
|
||||
const float inv_sum = 1.0f / sum;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const uint idx = lane + i * WARP_SIZE;
|
||||
const bool is_active = !use_limit || (idx < limit);
|
||||
if (is_active) {
|
||||
vals[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
|
||||
if (row >= n_rows) {
|
||||
@@ -35,43 +84,16 @@ void main() {
|
||||
const uint weights_offset = n_expert_used * row;
|
||||
const uint ids_offset = n_experts * row;
|
||||
|
||||
float logits_r[experts_per_thread];
|
||||
|
||||
const float INFINITY = 1.0 / 0.0;
|
||||
float wt[experts_per_thread];
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const uint expert = i + gl_LocalInvocationID.x;
|
||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
|
||||
const uint expert = i + gl_LocalInvocationID.x;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
|
||||
}
|
||||
|
||||
float max_val = logits_r[0];
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
max_val = max(val, max_val);
|
||||
}
|
||||
|
||||
max_val = subgroupMax(max_val);
|
||||
|
||||
float wt[experts_per_thread];
|
||||
float tmp = 0.f;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
wt[i] = exp(val - max_val);
|
||||
tmp += wt[i];
|
||||
}
|
||||
|
||||
tmp = subgroupAdd(tmp);
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = wt[i] * inv_sum;
|
||||
if (!late_softmax) {
|
||||
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
|
||||
}
|
||||
|
||||
// at this point, each thread holds a portion of softmax,
|
||||
@@ -82,6 +104,11 @@ void main() {
|
||||
|
||||
float output_weights[experts_per_thread];
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
output_weights[i] = 0.f;
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
uint max_expert = gl_LocalInvocationID.x;
|
||||
@@ -121,6 +148,7 @@ void main() {
|
||||
|
||||
if (with_norm) {
|
||||
wt_sum = subgroupAdd(wt_sum);
|
||||
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
[[unroll]]
|
||||
@@ -129,6 +157,10 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
if (late_softmax) {
|
||||
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
|
||||
|
||||
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q4_0
|
||||
#define A_TYPE_PACKED16 block_q4_0_packed16
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q4_1 32
|
||||
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
|
||||
#define A_TYPE block_q4_1
|
||||
#define A_TYPE_PACKED16 block_q4_1_packed16
|
||||
#define A_TYPE_PACKED32 block_q4_1_packed32
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_0 32
|
||||
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q5_0
|
||||
#define A_TYPE_PACKED16 block_q5_0_packed16
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_1 32
|
||||
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
|
||||
#define A_TYPE block_q5_1
|
||||
#define A_TYPE_PACKED16 block_q5_1_packed16
|
||||
#define A_TYPE_PACKED32 block_q5_1_packed32
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q8_0 32
|
||||
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
|
||||
#define A_TYPE block_q8_0
|
||||
#define A_TYPE_PACKED16 block_q8_0_packed16
|
||||
#define A_TYPE_PACKED32 block_q8_0_packed32
|
||||
#define DATA_A_QUANT_LEGACY
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q8_1 32
|
||||
@@ -226,21 +231,21 @@ struct block_q2_K
|
||||
{
|
||||
uint8_t scales[QUANT_K_Q2_K/16];
|
||||
uint8_t qs[QUANT_K_Q2_K/4];
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
};
|
||||
|
||||
struct block_q2_K_packed16
|
||||
{
|
||||
uint16_t scales[QUANT_K_Q2_K/16/2];
|
||||
uint16_t qs[QUANT_K_Q2_K/4/2];
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
};
|
||||
|
||||
struct block_q2_K_packed32
|
||||
{
|
||||
uint32_t scales[QUANT_K_Q2_K/16/4];
|
||||
uint32_t qs[QUANT_K_Q2_K/4/4];
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
};
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
|
||||
#define A_TYPE block_q2_K
|
||||
#define A_TYPE_PACKED16 block_q2_K_packed16
|
||||
#define A_TYPE_PACKED32 block_q2_K_packed32
|
||||
#define SCALES_PER_32 2
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q3_K 256
|
||||
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
|
||||
#define QUANT_R 1
|
||||
#define A_TYPE block_q3_K
|
||||
#define A_TYPE_PACKED16 block_q3_K_packed16
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q4_K 256
|
||||
|
||||
struct block_q4_K
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint8_t scales[3*QUANT_K_Q4_K/64];
|
||||
uint8_t qs[QUANT_K_Q4_K/2];
|
||||
};
|
||||
|
||||
struct block_q4_K_packed16
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint16_t scales[3*QUANT_K_Q4_K/64/2];
|
||||
uint16_t qs[QUANT_K_Q4_K/2/2];
|
||||
};
|
||||
|
||||
struct block_q4_K_packed32
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint32_t scales[3*QUANT_K_Q4_K/64/4];
|
||||
uint32_t qs[QUANT_K_Q4_K/2/4];
|
||||
};
|
||||
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
|
||||
#define A_TYPE block_q4_K
|
||||
#define A_TYPE_PACKED16 block_q4_K_packed16
|
||||
#define A_TYPE_PACKED32 block_q4_K_packed32
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_K 256
|
||||
|
||||
struct block_q5_K
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint8_t scales[12];
|
||||
uint8_t qh[QUANT_K_Q5_K/8];
|
||||
uint8_t qs[QUANT_K_Q5_K/2];
|
||||
@@ -324,12 +333,20 @@ struct block_q5_K
|
||||
|
||||
struct block_q5_K_packed16
|
||||
{
|
||||
f16vec2 d;
|
||||
f16vec2 dm;
|
||||
uint16_t scales[12/2];
|
||||
uint16_t qh[QUANT_K_Q5_K/8/2];
|
||||
uint16_t qs[QUANT_K_Q5_K/2/2];
|
||||
};
|
||||
|
||||
struct block_q5_K_packed32
|
||||
{
|
||||
f16vec2 dm;
|
||||
uint32_t scales[12/4];
|
||||
uint32_t qh[QUANT_K_Q5_K/8/4];
|
||||
uint32_t qs[QUANT_K_Q5_K/2/4];
|
||||
};
|
||||
|
||||
struct block_q5_K_packed128
|
||||
{
|
||||
uvec4 q5k[11];
|
||||
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
|
||||
#define QUANT_R 1
|
||||
#define A_TYPE block_q5_K
|
||||
#define A_TYPE_PACKED16 block_q5_K_packed16
|
||||
#define A_TYPE_PACKED32 block_q5_K_packed32
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q6_K 256
|
||||
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
|
||||
{
|
||||
uint16_t ql[QUANT_K_Q6_K/2/2];
|
||||
uint16_t qh[QUANT_K_Q6_K/4/2];
|
||||
int8_t scales[QUANT_K_Q6_K/16];
|
||||
int16_t scales[QUANT_K_Q6_K/16/2];
|
||||
float16_t d;
|
||||
};
|
||||
|
||||
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
|
||||
#define QUANT_R 1
|
||||
#define A_TYPE block_q6_K
|
||||
#define A_TYPE_PACKED16 block_q6_K_packed16
|
||||
#define DATA_A_QUANT_K
|
||||
#endif
|
||||
|
||||
// IQuants
|
||||
@@ -1363,18 +1383,11 @@ struct block_mxfp4
|
||||
uint8_t qs[QUANT_K_MXFP4/2];
|
||||
};
|
||||
|
||||
//struct block_mxfp4_packed16
|
||||
//{
|
||||
// uint8_t e;
|
||||
// uint16_t qs[QUANT_K_MXFP4/2/2];
|
||||
//};
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
#define QUANT_K QUANT_K_MXFP4
|
||||
#define QUANT_R QUANT_R_MXFP4
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_mxfp4
|
||||
//#define A_TYPE_PACKED16 block_mxfp4_packed16
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
||||
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
|
||||
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
|
||||
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
|
||||
const int8_t kvalues_mxfp4_const[16] = {
|
||||
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
|
||||
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
|
||||
};
|
||||
|
||||
shared FLOAT_TYPE kvalues_mxfp4[16];
|
||||
shared int8_t kvalues_mxfp4[16];
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
|
||||
@@ -7,6 +7,7 @@ layout (push_constant) uniform parameter
|
||||
uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
float pixel_offset;
|
||||
} p;
|
||||
|
||||
#include "types.glsl"
|
||||
@@ -19,7 +20,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
|
||||
#define NEAREST 0
|
||||
#define BILINEAR 1
|
||||
#define ALIGN_CORNERS (1 << 8)
|
||||
|
||||
layout (constant_id = 0) const uint scale_mode = 0;
|
||||
|
||||
@@ -52,7 +52,7 @@ float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
|
||||
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
||||
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
|
||||
|
||||
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
|
||||
const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset;
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = max(ivec2(c0f), 0);
|
||||
@@ -61,16 +61,6 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
|
||||
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = ivec2(c0f);
|
||||
const ivec2 c1 = c0 + 1;
|
||||
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
@@ -91,9 +81,6 @@ void main() {
|
||||
case BILINEAR:
|
||||
result = interpolate_bilinear(i10, i11, i12, i13);
|
||||
break;
|
||||
case BILINEAR | ALIGN_CORNERS:
|
||||
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
|
||||
break;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(result);
|
||||
|
||||
@@ -317,7 +317,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
|
||||
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
|
||||
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
|
||||
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O";
|
||||
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
|
||||
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos || name.find("rope") != std::string::npos) ? "" : "-O";
|
||||
|
||||
#ifdef _WIN32
|
||||
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
|
||||
@@ -566,7 +567,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
|
||||
// Integer dot mmq performs better with f32 accumulators
|
||||
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
@@ -574,7 +576,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
void process_shaders() {
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
|
||||
|
||||
// matmul
|
||||
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
||||
@@ -841,10 +843,14 @@ void process_shaders() {
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
|
||||
@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
|
||||
let is_neox = bool(params.mode & 2);
|
||||
let is_mrope = bool(params.mode & 8);
|
||||
let is_imrope = params.mode == 40;
|
||||
let is_vision = params.mode == 24;
|
||||
|
||||
var i = gid.x * 2; // start index for this thread
|
||||
@@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let sec_w = params.sections1 + params.sections0;
|
||||
let sec_e = params.sections2 + sec_w;
|
||||
let sector = (i0 / 2) % sect_dims;
|
||||
if (sector >= params.sections0 && sector < sec_w) {
|
||||
theta_base_mult = 1;
|
||||
if (is_vision) {
|
||||
theta_scale_pwr = sector - params.sections0;
|
||||
}
|
||||
} else if (sector >= sec_w && sector < sec_e) {
|
||||
theta_base_mult = 2;
|
||||
if (is_vision) {
|
||||
theta_scale_pwr = sector - sec_w;
|
||||
}
|
||||
} else if (sector >= sec_e) {
|
||||
if (is_vision) {
|
||||
theta_scale_pwr = sector - sec_e;
|
||||
theta_scale_pwr = (i0 / 2) % sec_e;
|
||||
}
|
||||
theta_base_mult = 3;
|
||||
} else if (is_vision) {
|
||||
theta_scale_pwr = sector;
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * params.sections1) {
|
||||
theta_base_mult = 1;
|
||||
} else if (sector % 3 == 2 && sector < 3 * params.sections2) {
|
||||
theta_base_mult = 2;
|
||||
} else if (sector % 3 == 0 && sector < 3 * params.sections0) {
|
||||
theta_base_mult = 0;
|
||||
} else {
|
||||
theta_base_mult = 3;
|
||||
}
|
||||
} else {
|
||||
if (sector >= params.sections0 && sector < sec_w) {
|
||||
theta_base_mult = 1;
|
||||
if (is_vision) {
|
||||
theta_scale_pwr = sector - params.sections0;
|
||||
}
|
||||
} else if (sector >= sec_w && sector < sec_e) {
|
||||
theta_base_mult = 2;
|
||||
if (is_vision) {
|
||||
theta_scale_pwr = sector - sec_w;
|
||||
}
|
||||
} else if (sector >= sec_e) {
|
||||
if (is_vision) {
|
||||
theta_scale_pwr = sector - sec_e;
|
||||
theta_scale_pwr = (i0 / 2) % sec_e;
|
||||
}
|
||||
theta_base_mult = 3;
|
||||
} else if (is_vision) {
|
||||
theta_scale_pwr = sector;
|
||||
}
|
||||
}
|
||||
}
|
||||
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
|
||||
|
||||
@@ -111,6 +111,7 @@ class Keys:
|
||||
EXPERTS_PER_GROUP = "{arch}.experts_per_group"
|
||||
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
|
||||
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
|
||||
NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||
@@ -277,6 +278,7 @@ class Keys:
|
||||
USE_GELU = "clip.use_gelu"
|
||||
USE_SILU = "clip.use_silu"
|
||||
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
|
||||
IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.vision.attention.head_count"
|
||||
@@ -350,6 +352,8 @@ class MODEL_ARCH(IntEnum):
|
||||
QWEN2VL = auto()
|
||||
QWEN3 = auto()
|
||||
QWEN3MOE = auto()
|
||||
QWEN3VL = auto()
|
||||
QWEN3VLMOE = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PHIMOE = auto()
|
||||
@@ -420,6 +424,8 @@ class MODEL_ARCH(IntEnum):
|
||||
SEED_OSS = auto()
|
||||
GROVEMOE = auto()
|
||||
APERTUS = auto()
|
||||
COGVLM = auto()
|
||||
MINIMAXM2 = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -430,6 +436,8 @@ class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
GLM_EDGE = auto()
|
||||
MERGER = auto()
|
||||
GEMMA3 = auto()
|
||||
QWEN3VL = auto()
|
||||
COGVLM = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@@ -600,6 +608,11 @@ class MODEL_TENSOR(IntEnum):
|
||||
SHORTCONV_CONV = auto()
|
||||
SHORTCONV_INPROJ = auto()
|
||||
SHORTCONV_OUTPROJ = auto()
|
||||
VISEXP_ATTN_QKV = auto()
|
||||
VISEXP_ATTN_OUT = auto()
|
||||
VISEXP_GATE = auto()
|
||||
VISEXP_DOWN = auto()
|
||||
VISEXP_UP = auto()
|
||||
# vision
|
||||
V_MMPROJ = auto()
|
||||
V_MMPROJ_FC = auto()
|
||||
@@ -609,6 +622,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
V_ENC_EMBD_PATCH = auto()
|
||||
V_ENC_EMBD_POS = auto()
|
||||
V_ENC_INPUT_NORM = auto()
|
||||
V_ENC_ATTN_QKV = auto()
|
||||
V_ENC_ATTN_Q = auto()
|
||||
V_ENC_ATTN_Q_NORM = auto()
|
||||
V_ENC_ATTN_K = auto()
|
||||
@@ -640,6 +654,15 @@ class MODEL_TENSOR(IntEnum):
|
||||
V_RESMPL_QUERY = auto() # minicpmv
|
||||
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
|
||||
V_MM_PATCH_MERGER = auto() # mistral small 3.1
|
||||
V_DS_NORM = auto() # qwen3vl
|
||||
V_DS_FC1 = auto() # qwen3vl
|
||||
V_DS_FC2 = auto() # qwen3vl
|
||||
V_MM_POST_FC_NORM = auto() # cogvlm
|
||||
V_MM_UP = auto() # cogvlm
|
||||
V_MM_DOWN = auto() # cogvlm
|
||||
V_MM_GATE = auto() # cogvlm
|
||||
V_TOK_BOI = auto() # cogvlm
|
||||
V_TOK_EOI = auto() # cogvlm
|
||||
# audio (mtmd)
|
||||
A_ENC_EMBD_POS = auto()
|
||||
A_ENC_CONV1D = auto()
|
||||
@@ -695,6 +718,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.QWEN3: "qwen3",
|
||||
MODEL_ARCH.QWEN3MOE: "qwen3moe",
|
||||
MODEL_ARCH.QWEN3VL: "qwen3vl",
|
||||
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
MODEL_ARCH.PHI3: "phi3",
|
||||
MODEL_ARCH.PHIMOE: "phimoe",
|
||||
@@ -766,6 +791,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.SEED_OSS: "seed_oss",
|
||||
MODEL_ARCH.GROVEMOE: "grovemoe",
|
||||
MODEL_ARCH.APERTUS: "apertus",
|
||||
MODEL_ARCH.MINIMAXM2: "minimax-m2",
|
||||
MODEL_ARCH.COGVLM: "cogvlm",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -946,6 +973,11 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
|
||||
MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
|
||||
MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
|
||||
MODEL_TENSOR.VISEXP_ATTN_QKV: "blk.{bid}.vis_attn_qkv",
|
||||
MODEL_TENSOR.VISEXP_ATTN_OUT: "blk.{bid}.vis_attn_output",
|
||||
MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate",
|
||||
MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down",
|
||||
MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up",
|
||||
# vision
|
||||
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
|
||||
@@ -954,6 +986,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
|
||||
MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
|
||||
@@ -986,6 +1019,15 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
|
||||
MODEL_TENSOR.V_DS_NORM: "v.deepstack.{bid}.norm",
|
||||
MODEL_TENSOR.V_DS_FC1: "v.deepstack.{bid}.fc1",
|
||||
MODEL_TENSOR.V_DS_FC2: "v.deepstack.{bid}.fc2",
|
||||
MODEL_TENSOR.V_MM_POST_FC_NORM: "mm.post_fc_norm", # cogvlm
|
||||
MODEL_TENSOR.V_MM_UP: "mm.up",
|
||||
MODEL_TENSOR.V_MM_DOWN: "mm.down",
|
||||
MODEL_TENSOR.V_MM_GATE: "mm.gate",
|
||||
MODEL_TENSOR.V_TOK_BOI: "v.boi",
|
||||
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
|
||||
# audio (mtmd)
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
@@ -1023,6 +1065,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH,
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_QKV,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K,
|
||||
@@ -1054,6 +1097,15 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.V_RESMPL_QUERY,
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER,
|
||||
MODEL_TENSOR.V_DS_NORM,
|
||||
MODEL_TENSOR.V_DS_FC1,
|
||||
MODEL_TENSOR.V_DS_FC2,
|
||||
MODEL_TENSOR.V_MM_POST_FC_NORM,
|
||||
MODEL_TENSOR.V_MM_UP,
|
||||
MODEL_TENSOR.V_MM_DOWN,
|
||||
MODEL_TENSOR.V_MM_GATE,
|
||||
MODEL_TENSOR.V_TOK_BOI,
|
||||
MODEL_TENSOR.V_TOK_EOI,
|
||||
# audio
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
@@ -1495,6 +1547,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.QWEN3VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN3VLMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.PLAMO: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -2837,6 +2923,41 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_CHEXP,
|
||||
MODEL_TENSOR.FFN_UP_CHEXP,
|
||||
],
|
||||
MODEL_ARCH.MINIMAXM2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.COGVLM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.VISEXP_ATTN_QKV,
|
||||
MODEL_TENSOR.VISEXP_ATTN_OUT,
|
||||
MODEL_TENSOR.VISEXP_GATE,
|
||||
MODEL_TENSOR.VISEXP_UP,
|
||||
MODEL_TENSOR.VISEXP_DOWN,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@@ -3055,6 +3176,7 @@ class VisionProjectorType:
|
||||
LLAMA4 = "llama4"
|
||||
QWEN2VL = "qwen2vl_merger"
|
||||
QWEN25VL = "qwen2.5vl_merger"
|
||||
QWEN3VL = "qwen3vl_merger"
|
||||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
@@ -3062,6 +3184,9 @@ class VisionProjectorType:
|
||||
VOXTRAL = "voxtral"
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
LIGHTONOCR = "lightonocr"
|
||||
COGVLM = "cogvlm"
|
||||
JANUS_PRO = "janus_pro"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
||||
@@ -860,6 +860,9 @@ class GGUFWriter:
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
def add_num_deepstack_layers(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
@@ -1071,6 +1074,9 @@ class GGUFWriter:
|
||||
def add_vision_n_wa_pattern(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
|
||||
|
||||
def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
|
||||
self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
|
||||
|
||||
# audio models
|
||||
|
||||
def add_audio_projection_dim(self, value: int) -> None:
|
||||
|
||||
@@ -104,6 +104,7 @@ class TensorNameMap:
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
"model.norm", # llama4
|
||||
"model.transformer.ln_f", # llada
|
||||
"model.norm", # cogvlm
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@@ -162,6 +163,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv6
|
||||
"model.layers.{bid}.ln2", # rwkv7
|
||||
"model.layers.{bid}.post_attention_layernorm", # cogvlm
|
||||
),
|
||||
|
||||
# Attention query-key-value
|
||||
@@ -184,6 +186,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
"transformer_encoder.{bid}.qkv", # neobert
|
||||
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
|
||||
),
|
||||
|
||||
# Attention query
|
||||
@@ -279,6 +282,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.attn_out", # llada
|
||||
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
|
||||
"model.layers.{bid}.self_attn.language_expert_dense", # cogvlm
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@@ -377,6 +381,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
|
||||
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
|
||||
),
|
||||
|
||||
# Feed-forward up
|
||||
@@ -418,6 +423,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.up_proj", # llada
|
||||
"layers.{bid}.mlp.up_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
|
||||
"model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -450,21 +456,22 @@ class TensorNameMap:
|
||||
|
||||
# Feed-forward gate
|
||||
MODEL_TENSOR.FFN_GATE: (
|
||||
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
|
||||
"layers.{bid}.mlp.gate_proj", # embeddinggemma
|
||||
"layers.{bid}.feed_forward.w1", # llama-pth
|
||||
"transformer.h.{bid}.mlp.w2", # qwen
|
||||
"transformer.h.{bid}.mlp.c_fc2", # jais
|
||||
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w1", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
|
||||
"model.transformer.blocks.{bid}.ff_proj", # llada
|
||||
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
|
||||
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
|
||||
"layers.{bid}.mlp.gate_proj", # embeddinggemma
|
||||
"layers.{bid}.feed_forward.w1", # llama-pth
|
||||
"transformer.h.{bid}.mlp.w2", # qwen
|
||||
"transformer.h.{bid}.mlp.c_fc2", # jais
|
||||
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w1", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
|
||||
"model.transformer.blocks.{bid}.ff_proj", # llada
|
||||
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
|
||||
"model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
@@ -522,6 +529,7 @@ class TensorNameMap:
|
||||
"model.transformer.blocks.{bid}.ff_out", # llada
|
||||
"layers.{bid}.mlp.down_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
|
||||
"model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -1047,6 +1055,26 @@ class TensorNameMap:
|
||||
"encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.VISEXP_UP: (
|
||||
"model.layers.{bid}.mlp.vision_mlp.up_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.VISEXP_GATE: (
|
||||
"model.layers.{bid}.mlp.vision_mlp.gate_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.VISEXP_DOWN: (
|
||||
"model.layers.{bid}.mlp.vision_mlp.down_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.VISEXP_ATTN_OUT: (
|
||||
"model.layers.{bid}.self_attn.vision_expert_dense", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.VISEXP_ATTN_QKV: (
|
||||
"model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm
|
||||
),
|
||||
|
||||
############################################################################
|
||||
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||
@@ -1148,12 +1176,14 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_FC: (
|
||||
"model.connector.modality_projection.proj", # SmolVLM
|
||||
"model.vision.linear_proj.linear_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_MLP: (
|
||||
"model.mm_projector.mlp.mlp.{bid}",
|
||||
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
|
||||
"mlp1.{bid}", # InternVL
|
||||
"model.aligner.fc1.hidden_layers.{bid}", # Janus Pro
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_PEG: (
|
||||
@@ -1164,6 +1194,7 @@ class TensorNameMap:
|
||||
"vision_tower.vision_model.embeddings.class_embedding",
|
||||
"model.vision_tower.embeddings.cls_token", # Intern-S1
|
||||
"vision_model.class_embedding", # llama 4
|
||||
"model.vision.patch_embedding.cls_embedding", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
|
||||
@@ -1176,6 +1207,7 @@ class TensorNameMap:
|
||||
"vision_model.patch_embedding.linear", # llama 4
|
||||
"visual.patch_embed.proj", # qwen2vl
|
||||
"vision_tower.patch_embed.proj", # kimi-vl
|
||||
"model.vision.patch_embedding.proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: (
|
||||
@@ -1185,6 +1217,13 @@ class TensorNameMap:
|
||||
"model.vision_model.embeddings.position_embedding", # SmolVLM
|
||||
"vision_model.positional_embedding_vlm", # llama 4
|
||||
"vision_tower.patch_embed.pos_emb", # kimi-vl
|
||||
"visual.pos_embed", # qwen3vl
|
||||
"model.vision.patch_embedding.position_embedding", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_QKV: (
|
||||
"visual.blocks.{bid}.attn.qkv", # qwen3vl
|
||||
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
@@ -1244,6 +1283,7 @@ class TensorNameMap:
|
||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"visual.blocks.{bid}.norm1", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
|
||||
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: (
|
||||
@@ -1252,11 +1292,13 @@ class TensorNameMap:
|
||||
"model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
|
||||
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro
|
||||
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
|
||||
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
|
||||
"visual.blocks.{bid}.attn.proj", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
|
||||
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
|
||||
@@ -1270,6 +1312,7 @@ class TensorNameMap:
|
||||
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
|
||||
"visual.blocks.{bid}.norm2", # qwen2vl
|
||||
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
|
||||
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
@@ -1282,7 +1325,9 @@ class TensorNameMap:
|
||||
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
||||
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
|
||||
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
|
||||
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
||||
@@ -1301,7 +1346,9 @@ class TensorNameMap:
|
||||
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
||||
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
|
||||
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
|
||||
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: (
|
||||
@@ -1338,6 +1385,7 @@ class TensorNameMap:
|
||||
"multi_modal_projector.layer_norm",
|
||||
"multi_modal_projector.pre_norm",
|
||||
"pre_mm_projector_norm",
|
||||
"model.vision.linear_proj.norm1", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
||||
@@ -1397,6 +1445,42 @@ class TensorNameMap:
|
||||
"patch_merger.merging_layer", # mistral
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_DS_NORM: (
|
||||
"model.visual.deepstack_merger_list.{bid}.norm", # deepstack in qwen3vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_DS_FC1: (
|
||||
"model.visual.deepstack_merger_list.{bid}.linear_fc1", # deepstack in qwen3vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_DS_FC2: (
|
||||
"model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_POST_FC_NORM: (
|
||||
"model.vision.linear_proj.norm1", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_UP: (
|
||||
"model.vision.linear_proj.dense_h_to_4h", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_DOWN: (
|
||||
"model.vision.linear_proj.dense_4h_to_h", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_GATE: (
|
||||
"model.vision.linear_proj.gate_proj", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_TOK_BOI: (
|
||||
"model.vision.boi", # cogvlm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_TOK_EOI: (
|
||||
"model.vision.eoi", # cogvlm
|
||||
),
|
||||
|
||||
# audio (mtmd)
|
||||
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user