mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-08 01:54:10 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80a6cf6347 | ||
|
|
0750a59903 | ||
|
|
aa3b7a90b4 | ||
|
|
333f2595a3 | ||
|
|
53d7d21e61 | ||
|
|
eeee367de5 | ||
|
|
64fe17fbb8 | ||
|
|
c1b187688d | ||
|
|
b8a5cfd11a | ||
|
|
08416ebe7f | ||
|
|
b4e335d8dc | ||
|
|
d6fe40fa00 | ||
|
|
e14e842e87 | ||
|
|
647b960bd8 | ||
|
|
299f5d782c | ||
|
|
ac76d36201 |
18
.github/workflows/build.yml
vendored
18
.github/workflows/build.yml
vendored
@@ -161,15 +161,16 @@ jobs:
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
run: |
|
||||
DAWN_VERSION="v1.0.0"
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
curl -L -o artifact.tar.gz \
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
mkdir dawn
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
unzip artifact.zip
|
||||
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -521,15 +522,16 @@ jobs:
|
||||
id: dawn-depends
|
||||
run: |
|
||||
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
|
||||
DAWN_VERSION="v1.0.0"
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
curl -L -o artifact.tar.gz \
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
mkdir dawn
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
unzip artifact.zip
|
||||
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
|
||||
@@ -740,6 +740,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
exit(0);
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-cl", "--cache-list"},
|
||||
"show list of models in cache",
|
||||
[](common_params &) {
|
||||
printf("model cache directory: %s\n", fs_get_cache_directory().c_str());
|
||||
auto models = common_list_cached_models();
|
||||
printf("number of models in cache: %zu\n", models.size());
|
||||
for (size_t i = 0; i < models.size(); i++) {
|
||||
auto & model = models[i];
|
||||
printf("%4d. %s\n", (int) i + 1, model.to_string().c_str());
|
||||
}
|
||||
exit(0);
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--completion-bash"},
|
||||
"print source-able bash completion script for llama.cpp",
|
||||
|
||||
@@ -908,6 +908,39 @@ std::string fs_get_cache_file(const std::string & filename) {
|
||||
return cache_directory + filename;
|
||||
}
|
||||
|
||||
std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||
std::vector<common_file_info> files;
|
||||
if (path.empty()) return files;
|
||||
|
||||
std::filesystem::path dir(path);
|
||||
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
|
||||
return files;
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
|
||||
try {
|
||||
// Only include regular files (skip directories)
|
||||
const auto & p = entry.path();
|
||||
if (std::filesystem::is_regular_file(p)) {
|
||||
common_file_info info;
|
||||
info.path = p.string();
|
||||
info.name = p.filename().string();
|
||||
try {
|
||||
info.size = static_cast<size_t>(std::filesystem::file_size(p));
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
info.size = 0;
|
||||
}
|
||||
files.push_back(std::move(info));
|
||||
}
|
||||
} catch (const std::filesystem::filesystem_error &) {
|
||||
// skip entries we cannot inspect
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return files;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Model utils
|
||||
|
||||
@@ -611,6 +611,13 @@ bool fs_create_directory_with_parents(const std::string & path);
|
||||
std::string fs_get_cache_directory();
|
||||
std::string fs_get_cache_file(const std::string & filename);
|
||||
|
||||
struct common_file_info {
|
||||
std::string path;
|
||||
std::string name;
|
||||
size_t size = 0; // in bytes
|
||||
};
|
||||
std::vector<common_file_info> fs_list_files(const std::string & path);
|
||||
|
||||
//
|
||||
// Model utils
|
||||
//
|
||||
|
||||
@@ -50,6 +50,22 @@ using json = nlohmann::ordered_json;
|
||||
// downloader
|
||||
//
|
||||
|
||||
// validate repo name format: owner/repo
|
||||
static bool validate_repo_name(const std::string & repo) {
|
||||
static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
|
||||
return std::regex_match(repo, repo_regex);
|
||||
}
|
||||
|
||||
static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
|
||||
// we use "=" to avoid clashing with other component, while still being allowed on windows
|
||||
std::string fname = "manifest=" + repo + "=" + tag + ".json";
|
||||
if (!validate_repo_name(repo)) {
|
||||
throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
|
||||
}
|
||||
string_replace_all(fname, "/", "=");
|
||||
return fs_get_cache_file(fname);
|
||||
}
|
||||
|
||||
static std::string read_file(const std::string & fname) {
|
||||
std::ifstream file(fname);
|
||||
if (!file) {
|
||||
@@ -829,17 +845,13 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
|
||||
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
|
||||
// User-Agent header is already set in common_remote_get_content, no need to set it here
|
||||
|
||||
// we use "=" to avoid clashing with other component, while still being allowed on windows
|
||||
std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json";
|
||||
string_replace_all(cached_response_fname, "/", "_");
|
||||
std::string cached_response_path = fs_get_cache_file(cached_response_fname);
|
||||
|
||||
// make the request
|
||||
common_remote_params params;
|
||||
params.headers = headers;
|
||||
long res_code = 0;
|
||||
std::string res_str;
|
||||
bool use_cache = false;
|
||||
std::string cached_response_path = get_manifest_path(hf_repo, tag);
|
||||
if (!offline) {
|
||||
try {
|
||||
auto res = common_remote_get_content(url, params);
|
||||
@@ -895,6 +907,33 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
|
||||
return { hf_repo, ggufFile, mmprojFile };
|
||||
}
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
const std::vector<common_file_info> files = fs_list_files(cache_dir);
|
||||
for (const auto & file : files) {
|
||||
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
|
||||
common_cached_model_info model_info;
|
||||
model_info.manifest_path = file.path;
|
||||
std::string fname = file.name;
|
||||
string_replace_all(fname, ".json", ""); // remove extension
|
||||
auto parts = string_split<std::string>(fname, '=');
|
||||
if (parts.size() == 4) {
|
||||
// expect format: manifest=<user>=<model>=<tag>=<other>
|
||||
model_info.user = parts[1];
|
||||
model_info.model = parts[2];
|
||||
model_info.tag = parts[3];
|
||||
} else {
|
||||
// invalid format
|
||||
continue;
|
||||
}
|
||||
model_info.size = 0; // TODO: get GGUF size, not manifest size
|
||||
models.push_back(model_info);
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
//
|
||||
// Docker registry functions
|
||||
//
|
||||
@@ -959,6 +998,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
std::string token = common_docker_get_token(repo); // Get authentication token
|
||||
|
||||
// Get manifest
|
||||
// TODO: cache the manifest response so that it appears in the model list
|
||||
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
|
||||
std::string manifest_url = url_prefix + "/manifests/" + tag;
|
||||
common_remote_params manifest_params;
|
||||
|
||||
@@ -8,16 +8,23 @@ struct common_params_model;
|
||||
// download functionalities
|
||||
//
|
||||
|
||||
struct common_cached_model_info {
|
||||
std::string manifest_path;
|
||||
std::string user;
|
||||
std::string model;
|
||||
std::string tag;
|
||||
size_t size = 0; // GGUF size in bytes
|
||||
std::string to_string() const {
|
||||
return user + "/" + model + ":" + tag;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_hf_file_res {
|
||||
std::string repo; // repo name with ":tag" removed
|
||||
std::string ggufFile;
|
||||
std::string mmprojFile;
|
||||
};
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
std::string common_docker_resolve_model(const std::string & docker);
|
||||
|
||||
/**
|
||||
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
||||
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
||||
@@ -39,3 +46,10 @@ bool common_download_model(
|
||||
const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline);
|
||||
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
std::string common_docker_resolve_model(const std::string & docker);
|
||||
|
||||
@@ -168,7 +168,7 @@ option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
|
||||
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
||||
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
||||
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
||||
option(GGML_VXE "ggml: enable vxe" ON)
|
||||
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
|
||||
@@ -124,6 +124,7 @@ if (CUDAToolkit_FOUND)
|
||||
|
||||
if (GGML_CUDA_DEBUG)
|
||||
list(APPEND CUDA_FLAGS -lineinfo)
|
||||
add_compile_definitions(GGML_CUDA_DEBUG)
|
||||
endif()
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
||||
|
||||
@@ -198,7 +198,7 @@ static void ggml_cpy_flt_cuda(
|
||||
if (transposed) {
|
||||
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
|
||||
int ne00n, ne01n, ne02n;
|
||||
if (nb00 < nb02) {
|
||||
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
|
||||
ne00n = ne00;
|
||||
ne01n = ne01;
|
||||
ne02n = ne02;
|
||||
@@ -206,8 +206,6 @@ static void ggml_cpy_flt_cuda(
|
||||
ne00n = ne00;
|
||||
ne01n = ne01*ne02;
|
||||
ne02n = 1;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
#include "ggml-cuda/mmq.cuh"
|
||||
#include "ggml-cuda/mmvf.cuh"
|
||||
#include "ggml-cuda/mmvq.cuh"
|
||||
#include "ggml-cuda/moe-expert-reduce.cuh"
|
||||
#include "ggml-cuda/norm.cuh"
|
||||
#include "ggml-cuda/opt-step-adamw.cuh"
|
||||
#include "ggml-cuda/opt-step-sgd.cuh"
|
||||
@@ -3152,8 +3151,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
const int nodes_fused = i - prev_i - 1;
|
||||
prev_i = i;
|
||||
@@ -3199,31 +3196,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL) {
|
||||
int current_node = i + 1;
|
||||
int num_views = 0;
|
||||
int num_adds = 0;
|
||||
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
|
||||
num_views++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD &&
|
||||
num_adds < num_views - 1) {
|
||||
num_adds++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
if (num_adds == num_views - 1 && num_views > 0) {
|
||||
ggml_tensor * dst_node = cgraph->nodes[current_node - 1];
|
||||
if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) {
|
||||
ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node);
|
||||
i += num_views + num_adds;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
@@ -3302,6 +3274,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
// we don't support repeating adds
|
||||
if (bias_op == GGML_OP_ADD &&
|
||||
(!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
|
||||
!ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up_n->src[0];
|
||||
const ggml_tensor * src1 = up_n->src[1];
|
||||
const ggml_tensor * ids = up_n->src[2];
|
||||
@@ -3411,6 +3390,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.x_bias = bias_tensor;
|
||||
|
||||
|
||||
@@ -3494,7 +3494,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int col_diff = col_high - col_low;
|
||||
|
||||
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
|
||||
ids_dst_shared[j] = ids_dst[col_low + j];
|
||||
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
#include "moe-expert-reduce.cuh"
|
||||
|
||||
// This kernel is a fusion of the expert weight reduce, common in MoE models
|
||||
|
||||
template <int n_expert_used_template>
|
||||
__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts,
|
||||
const float * __restrict__ weights,
|
||||
float * __restrict__ dst,
|
||||
const int n_expert_used,
|
||||
const int n_cols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
if (col >= n_cols) {
|
||||
return;
|
||||
}
|
||||
|
||||
experts += row * n_cols * n_expert_used;
|
||||
weights += row * n_expert_used;
|
||||
dst += row * n_cols;
|
||||
|
||||
float acc = 0.f;
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
for (int expert = 0; expert < n_expert_used; ++expert) {
|
||||
ggml_cuda_mad(acc, experts[col], weights[expert]);
|
||||
experts += n_cols;
|
||||
}
|
||||
dst[col] = acc;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_expert_used_template; ++i) {
|
||||
ggml_cuda_mad(acc, experts[col], weights[i]);
|
||||
experts += n_cols;
|
||||
}
|
||||
dst[col] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const float * experts,
|
||||
const float * weights,
|
||||
float * dst,
|
||||
const int n_expert_used,
|
||||
const int n_cols,
|
||||
const int n_rows) {
|
||||
const int block_size = 32;
|
||||
|
||||
const int n_blocks_x = n_rows;
|
||||
const int n_blocks_y = (n_cols + block_size - 1) / block_size;
|
||||
|
||||
dim3 block_dims(block_size);
|
||||
dim3 grid_dims(n_blocks_x, n_blocks_y);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
switch (n_expert_used) {
|
||||
case 1:
|
||||
moe_expert_reduce_cuda<1>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 2:
|
||||
moe_expert_reduce_cuda<2>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 4:
|
||||
moe_expert_reduce_cuda<4>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 6:
|
||||
moe_expert_reduce_cuda<6>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 8:
|
||||
moe_expert_reduce_cuda<8>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 16:
|
||||
moe_expert_reduce_cuda<16>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 32:
|
||||
moe_expert_reduce_cuda<32>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 64:
|
||||
moe_expert_reduce_cuda<64>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
case 128:
|
||||
moe_expert_reduce_cuda<128>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
default:
|
||||
moe_expert_reduce_cuda<0>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(experts, weights, dst, n_expert_used, n_cols);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) {
|
||||
const ggml_tensor * mul = cgraph->nodes[start_index];
|
||||
|
||||
if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int current_node = start_index + 1;
|
||||
size_t current_offset = 0;
|
||||
|
||||
std::vector<const ggml_tensor *> view_nodes;
|
||||
//check if all are views of the expert in increasing order
|
||||
while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) {
|
||||
const ggml_tensor * node = cgraph->nodes[current_node];
|
||||
if (node->view_src != mul) {
|
||||
return false;
|
||||
}
|
||||
if (node->view_offs < current_offset) {
|
||||
return false;
|
||||
}
|
||||
current_offset = node->view_offs;
|
||||
current_node++;
|
||||
view_nodes.push_back(node);
|
||||
}
|
||||
|
||||
//check if all the adds are in increasing order
|
||||
const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0];
|
||||
int num_adds = 0;
|
||||
int num_views = view_nodes.size();
|
||||
while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) {
|
||||
const ggml_tensor * add_node = cgraph->nodes[current_node];
|
||||
|
||||
bool is_first_op_ok = num_views > num_adds ? add_node->src[0] == prev_add_src : false;
|
||||
bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false;
|
||||
|
||||
if (!is_first_op_ok || !is_second_op_ok) {
|
||||
return false;
|
||||
}
|
||||
prev_add_src = add_node;
|
||||
|
||||
num_adds++;
|
||||
current_node++;
|
||||
}
|
||||
|
||||
if (num_views != num_adds + 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * experts,
|
||||
const ggml_tensor * weights,
|
||||
ggml_tensor * dst) {
|
||||
const int n_rows = experts->ne[2];
|
||||
const int n_expert_used = experts->ne[1];
|
||||
const int n_cols = experts->ne[0];
|
||||
|
||||
GGML_ASSERT(experts->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(experts));
|
||||
GGML_ASSERT(ggml_is_contiguous(weights));
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * experts_d = (const float *) experts->data;
|
||||
const float * weights_d = (const float *) weights->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows);
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
#include "common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * experts,
|
||||
const ggml_tensor * weights,
|
||||
ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index);
|
||||
@@ -289,7 +289,7 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
|
||||
|
||||
// queue the copy operation into the queue of the Metal context
|
||||
// this will be queued at the end, after any currently ongoing GPU operations
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
|
||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||
|
||||
[encoder copyFromBuffer:buf_src
|
||||
@@ -300,6 +300,7 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
|
||||
|
||||
[encoder endEncoding];
|
||||
[cmd_buf commit];
|
||||
[buf_src release];
|
||||
|
||||
// do not wait here for completion
|
||||
//[cmd_buf waitUntilCompleted];
|
||||
@@ -330,7 +331,7 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
|
||||
|
||||
// queue the copy operation into the queue of the Metal context
|
||||
// this will be queued at the end, after any currently ongoing GPU operations
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
|
||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||
|
||||
[encoder copyFromBuffer:bid_src.metal
|
||||
@@ -341,6 +342,7 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
|
||||
|
||||
[encoder endEncoding];
|
||||
[cmd_buf commit];
|
||||
[buf_dst release];
|
||||
|
||||
// do not wait here for completion
|
||||
//[cmd_buf waitUntilCompleted];
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -62,14 +62,8 @@ layout(push_constant) uniform parameter {
|
||||
uint32_t nb3;
|
||||
|
||||
// fastdiv helper values
|
||||
uint32_t KWmp; uint32_t KWL;
|
||||
uint32_t KWKHmp; uint32_t KWKHL;
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
#ifdef TRANSPOSE
|
||||
uint32_t s0mp; uint32_t s0L;
|
||||
uint32_t s1mp; uint32_t s1L;
|
||||
#endif
|
||||
}
|
||||
|
||||
p;
|
||||
@@ -84,6 +78,15 @@ layout(constant_id = 4) const uint TS_K = 8;
|
||||
layout(constant_id = 5) const uint use_collectives = 1;
|
||||
layout(constant_id = 6) const uint SHMEM_PAD = 4;
|
||||
|
||||
layout(constant_id = 7) const uint s0 = 1;
|
||||
layout(constant_id = 8) const uint s1 = 1;
|
||||
layout(constant_id = 9) const uint p0 = 0;
|
||||
layout(constant_id = 10) const uint p1 = 0;
|
||||
layout(constant_id = 11) const uint d0 = 1;
|
||||
layout(constant_id = 12) const uint d1 = 1;
|
||||
layout(constant_id = 13) const uint KW = 1;
|
||||
layout(constant_id = 14) const uint KH = 1;
|
||||
|
||||
uint32_t tid = gl_LocalInvocationID.x;
|
||||
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
||||
|
||||
@@ -92,7 +95,7 @@ uint splitWork(uint work_size, uint block_size) {
|
||||
}
|
||||
|
||||
uint32_t K = p.Cout;
|
||||
uint32_t CRS = p.Cin * p.KH * p.KW;
|
||||
uint32_t CRS = p.Cin * KH * KW;
|
||||
uint32_t NPQ = p.N * p.OH * p.OW;
|
||||
|
||||
uint32_t n_elems_out = K * NPQ;
|
||||
@@ -187,7 +190,7 @@ void main() {
|
||||
}
|
||||
#endif
|
||||
/* Advance block in CRS dim */
|
||||
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
||||
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
||||
uint32_t CRS_idx_a;
|
||||
uint32_t Cin_idx_a;
|
||||
uint32_t KH_idx_a;
|
||||
@@ -200,10 +203,10 @@ void main() {
|
||||
uint32_t cached_KW_idx;
|
||||
if (use_collectives == 1) {
|
||||
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
|
||||
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
|
||||
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
|
||||
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
|
||||
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
|
||||
cached_Cin_idx = cached_CRS_idx / (KW * KH);
|
||||
uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH);
|
||||
cached_KH_idx = cached_CRS_remainder / KW;
|
||||
cached_KW_idx = cached_CRS_remainder % KW;
|
||||
|
||||
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
|
||||
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
|
||||
@@ -211,21 +214,21 @@ void main() {
|
||||
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
|
||||
} else {
|
||||
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
|
||||
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
|
||||
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
|
||||
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
|
||||
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
|
||||
Cin_idx_a = CRS_idx_a / (KW * KH);
|
||||
uint32_t CRS_remainder = CRS_idx_a % (KW * KH);
|
||||
KH_idx_a = CRS_remainder / KW;
|
||||
KW_idx_a = CRS_remainder % KW;
|
||||
}
|
||||
#else
|
||||
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
|
||||
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
|
||||
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
|
||||
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
|
||||
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
|
||||
Cin_idx_a = CRS_idx_a / (KW * KH);
|
||||
CRS_remainder = CRS_idx_a % (KW * KH);
|
||||
KH_idx_a = CRS_remainder / KW;
|
||||
KW_idx_a = CRS_remainder % KW;
|
||||
#endif
|
||||
|
||||
/* Load kernel to A_block: (BS_K x BS_CRS)*/
|
||||
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
||||
uint32_t B_ly = r_offset + Ar;
|
||||
uint32_t B_lx = Ac;
|
||||
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
|
||||
@@ -262,27 +265,27 @@ void main() {
|
||||
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
|
||||
} else {
|
||||
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
|
||||
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
|
||||
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
|
||||
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
|
||||
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
|
||||
Cin_idx_b = CRS_idx_b / (KW * KH);
|
||||
uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
|
||||
KH_idx_b = CRS_remainder / KW;
|
||||
KW_idx_b = CRS_remainder % KW;
|
||||
}
|
||||
#else
|
||||
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
|
||||
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
|
||||
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
|
||||
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
|
||||
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
|
||||
Cin_idx_b = CRS_idx_b / (KW * KH);
|
||||
uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
|
||||
KH_idx_b = CRS_remainder / KW;
|
||||
KW_idx_b = CRS_remainder % KW;
|
||||
#endif
|
||||
|
||||
#ifdef TRANSPOSE
|
||||
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
|
||||
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
|
||||
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
|
||||
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
|
||||
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1;
|
||||
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0;
|
||||
uint32_t H_idx = H_idx_x_s1 / s1;
|
||||
uint32_t W_idx = W_idx_x_s0 / s0;
|
||||
#else
|
||||
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
|
||||
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
|
||||
uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
|
||||
uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
|
||||
#endif
|
||||
uint32_t src_idx =
|
||||
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
|
||||
@@ -290,7 +293,7 @@ void main() {
|
||||
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
|
||||
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
|
||||
#ifdef TRANSPOSE
|
||||
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
|
||||
|| (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0)
|
||||
#endif
|
||||
) {
|
||||
val = 0.0;
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "utils.glsl"
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
#include "rope_params.glsl"
|
||||
#endif
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -12,11 +15,16 @@ layout (push_constant) uniform parameter
|
||||
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
|
||||
uint misalign_offsets;
|
||||
float param1; float param2; int param3;
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
rope_params rope;
|
||||
#endif
|
||||
} p;
|
||||
|
||||
#if !RMS_NORM_ROPE_FUSION
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#endif
|
||||
|
||||
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
|
||||
layout(constant_id = 0) const bool norepeat = false;
|
||||
|
||||
@@ -49,6 +49,7 @@ layout (push_constant) uniform parameter
|
||||
uint batch_stride_d;
|
||||
|
||||
uint enable_bias;
|
||||
uint enable_scale;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
@@ -129,6 +130,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
}
|
||||
}
|
||||
@@ -171,6 +178,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
}
|
||||
}
|
||||
@@ -203,6 +216,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,6 @@ layout (push_constant) uniform parameter
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
|
||||
layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
||||
layout (constant_id = 4) const uint WM = 32;
|
||||
layout (constant_id = 5) const uint WN = 32;
|
||||
layout (constant_id = 6) const uint WMITER = 2;
|
||||
@@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2;
|
||||
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
|
||||
layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#define BK 32
|
||||
#define BK_STEP 4
|
||||
#else
|
||||
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
||||
#define BK_STEP 2
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK / 2 + 4)
|
||||
#else
|
||||
@@ -244,8 +251,13 @@ void main() {
|
||||
}
|
||||
#else
|
||||
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC4 cache_b;
|
||||
#else
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b;
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
|
||||
@@ -283,24 +295,41 @@ void main() {
|
||||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < BK / 2; i++) {
|
||||
[[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint j = 0; j < TM; j++) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ];
|
||||
cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];
|
||||
#else
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ];
|
||||
cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];
|
||||
#else
|
||||
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
|
||||
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
|
||||
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
|
||||
#else
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,32 @@
|
||||
#include "generic_binary_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
|
||||
// data is passed from rms_norm -> rope through shared memory.
|
||||
// rms_norm calls this data_d, rope calls this rope_data_a.
|
||||
// Binding 2 is not used
|
||||
shared FLOAT_TYPE rope_data_a[1024];
|
||||
#define data_d rope_data_a
|
||||
|
||||
layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
|
||||
layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
|
||||
layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
|
||||
layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
|
||||
|
||||
#include "rope_params.glsl"
|
||||
#include "rope_funcs.glsl"
|
||||
|
||||
#define GGML_ROPE_TYPE_NORMAL 0
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
#define GGML_ROPE_TYPE_MROPE 8
|
||||
#define GGML_ROPE_TYPE_VISION 24
|
||||
|
||||
#endif
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
@@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {
|
||||
|
||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
// Per-row offset in shared memory
|
||||
uint32_t d_offset = 0;
|
||||
#else
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
#endif
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
@@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
barrier();
|
||||
rope_params rp = p.rope;
|
||||
uint rope_row = (samp*nchannels + channel)*nrows + row;
|
||||
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
|
||||
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
|
||||
rope_neox(t, rope_row, rp);
|
||||
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
|
||||
rope_norm(t, rope_row, rp);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void main() {
|
||||
|
||||
227
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
Normal file
227
ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
Normal file
@@ -0,0 +1,227 @@
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
|
||||
#if RMS_NORM_ROPE_FUSION
|
||||
// Per-row offset in shared memory
|
||||
const uint ix = i0;
|
||||
#else
|
||||
const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
|
||||
#endif
|
||||
return ix;
|
||||
}
|
||||
|
||||
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
|
||||
float mscale = p.attn_factor;
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = p.freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (p.ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
||||
}
|
||||
// Backprogagation uses inverted rotation
|
||||
if (p.is_back != 0) {
|
||||
theta = -theta;
|
||||
}
|
||||
cos_theta = cos(theta) * mscale;
|
||||
sin_theta = sin(theta) * mscale;
|
||||
}
|
||||
|
||||
void rope_norm(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
uint idst = i1*ne0 + i0;
|
||||
const uint ix = rope_a_coord(i0, i01, i02, p);
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = i01*ne0 + i0;
|
||||
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
|
||||
rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||
|
||||
const float x0 = float(rope_data_a[ix + 0]);
|
||||
const float x1 = float(rope_data_a[ix + 1]);
|
||||
|
||||
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
||||
void rope_neox(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
uint idst = i1*ne0 + i0/2;
|
||||
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = i01*ne0 + i0/2;
|
||||
idst += rope_data_i[i02].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
|
||||
rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||
|
||||
const float x0 = float(rope_data_a[ix + 0]);
|
||||
const float x1 = float(rope_data_a[ix + p.n_dims/2]);
|
||||
|
||||
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
||||
|
||||
void rope_multi(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
const uint idst = i1*ne0 + i0/2;
|
||||
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
|
||||
rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (p.is_imrope != 0) {
|
||||
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
||||
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||
|
||||
const float x0 = float(rope_data_a[ix + 0]);
|
||||
const float x1 = float(rope_data_a[ix + p.n_dims/2]);
|
||||
|
||||
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
||||
void rope_vision(const uint i0, const uint i1, rope_params p) {
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = i1 % ne1;
|
||||
const uint i02 = i1 / ne1;
|
||||
|
||||
const uint idst = i1*ne0 + i0/2;
|
||||
const uint ix = rope_a_coord(i0/2, i01, i02, p);
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
const uint p0 = sector;
|
||||
theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
const uint p0 = sector - p.sections[0];
|
||||
theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
|
||||
|
||||
const float x0 = float(rope_data_a[ix + 0]);
|
||||
const float x1 = float(rope_data_a[ix + p.n_dims]);
|
||||
|
||||
rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
}
|
||||
|
||||
@@ -3,56 +3,18 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "rope_params.glsl"
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {int data_pos[];};
|
||||
layout (binding = 2) readonly buffer Z {float data_ff[];};
|
||||
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
|
||||
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
|
||||
layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];};
|
||||
layout (binding = 1) readonly buffer Y {int rope_data_pos[];};
|
||||
layout (binding = 2) readonly buffer Z {float rope_data_ff[];};
|
||||
layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];};
|
||||
layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows
|
||||
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint ncols;
|
||||
uint n_dims;
|
||||
float freq_scale;
|
||||
uint p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
uint has_ff;
|
||||
uint ne02;
|
||||
uint s1;
|
||||
uint s2;
|
||||
int sections[4];
|
||||
uint is_imrope;
|
||||
uint is_back;
|
||||
uint set_rows_stride;
|
||||
} p;
|
||||
rope_params pc;
|
||||
};
|
||||
|
||||
float rope_yarn_ramp(const float low, const float high, const uint i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
|
||||
float mscale = p.attn_factor;
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = p.freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (p.ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
|
||||
}
|
||||
// Backprogagation uses inverted rotation
|
||||
if (p.is_back != 0) {
|
||||
theta = -theta;
|
||||
}
|
||||
cos_theta = cos(theta) * mscale;
|
||||
sin_theta = sin(theta) * mscale;
|
||||
}
|
||||
|
||||
@@ -1,70 +1,11 @@
|
||||
#version 450
|
||||
|
||||
#include "rope_head.glsl"
|
||||
#include "rope_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (p.is_imrope != 0) {
|
||||
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + p.n_dims/2]);
|
||||
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x;
|
||||
rope_multi(i0, i1, pc);
|
||||
}
|
||||
|
||||
@@ -1,48 +1,11 @@
|
||||
#version 450
|
||||
|
||||
#include "rope_head.glsl"
|
||||
#include "rope_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = row_x*ne0 + i0/2;
|
||||
idst += data_i[channel_x].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]);
|
||||
data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + p.n_dims/2]);
|
||||
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x;
|
||||
rope_neox(i0, i1, pc);
|
||||
}
|
||||
|
||||
@@ -1,48 +1,11 @@
|
||||
#version 450
|
||||
|
||||
#include "rope_head.glsl"
|
||||
#include "rope_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
uint idst = row_dst*ne0 + i0;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
||||
|
||||
// Fusion optimization: ROPE + VIEW + SET_ROWS..
|
||||
// The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
|
||||
if (p.set_rows_stride != 0) {
|
||||
idst = row_x*ne0 + i0;
|
||||
idst += data_i[channel_x].x * p.set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + 0] = D_TYPE(data_a[ix + 0]);
|
||||
data_d[idst + 1] = D_TYPE(data_a[ix + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + 1]);
|
||||
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x;
|
||||
rope_norm(i0, i1, pc);
|
||||
}
|
||||
|
||||
27
ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
Normal file
27
ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
Normal file
@@ -0,0 +1,27 @@
|
||||
#if !defined(GGML_ROPE_PARAMS)
|
||||
#define GGML_ROPE_PARAMS
|
||||
|
||||
#include "rte.glsl"
|
||||
|
||||
struct rope_params {
|
||||
uint rope_mode;
|
||||
uint ncols;
|
||||
uint n_dims;
|
||||
float freq_scale;
|
||||
uint p_delta_rows;
|
||||
float freq_base;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float corr_dims[2];
|
||||
float theta_scale;
|
||||
uint has_ff;
|
||||
uint ne02;
|
||||
uint nb01;
|
||||
uint nb02;
|
||||
int sections[4];
|
||||
uint is_imrope;
|
||||
uint is_back;
|
||||
uint set_rows_stride;
|
||||
};
|
||||
|
||||
#endif // !defined(GGML_ROPE_PARAMS)
|
||||
@@ -1,47 +1,11 @@
|
||||
#version 450
|
||||
|
||||
#include "rope_head.glsl"
|
||||
#include "rope_funcs.glsl"
|
||||
|
||||
void main() {
|
||||
const uint i0 = 2*gl_GlobalInvocationID.y;
|
||||
uint ne0 = p.ncols;
|
||||
uint ne1 = p.p_delta_rows;
|
||||
uint ne2 = p.ne02;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
const uint p0 = sector;
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
const uint p0 = sector - p.sections[0];
|
||||
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = float(data_a[ix + 0]);
|
||||
const float x1 = float(data_a[ix + p.n_dims]);
|
||||
|
||||
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
|
||||
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
|
||||
// i1 is actually i2*nb2+i1, but the rows are contiguous
|
||||
const uint i1 = gl_GlobalInvocationID.x;
|
||||
rope_vision(i0, i1, pc);
|
||||
}
|
||||
|
||||
@@ -695,6 +695,8 @@ void process_shaders() {
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
|
||||
string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -840,25 +842,25 @@ void process_shaders() {
|
||||
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
|
||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||
|
||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <condition_variable>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
@@ -73,6 +74,30 @@
|
||||
// For operations which process a row in parallel, this seems like a reasonable default
|
||||
#define WEBGPU_ROW_SPLIT_WG_SIZE 64
|
||||
|
||||
// Matrix multiplication parameters
|
||||
|
||||
// Register tiling parameters
|
||||
#define WEBGPU_MUL_MAT_TILE_M 8
|
||||
#define WEBGPU_MUL_MAT_TILE_N 8
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
||||
#define WEBGPU_MUL_MAT_TILE_K 32
|
||||
|
||||
// Subgroup matrix parameters
|
||||
// The number of subgroups in the M dimension
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
||||
// The number of subgroups in the N dimension
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
|
||||
// The number of subgroup matrices each subgroup accumulates over
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||
|
||||
// Matrix-vector multiplication parameters
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
|
||||
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
|
||||
|
||||
/* End Constants */
|
||||
|
||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
|
||||
@@ -236,6 +261,10 @@ struct webgpu_context_struct {
|
||||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t subgroup_size;
|
||||
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
|
||||
|
||||
// Separate this out from limits since on some Metal systems, the limit returned by
|
||||
// querying the limits is higher than the actual allowed maximum.
|
||||
uint32_t max_wg_size_x;
|
||||
@@ -247,6 +276,11 @@ struct webgpu_context_struct {
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
|
||||
webgpu_pipeline memset_pipeline;
|
||||
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
|
||||
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
|
||||
|
||||
webgpu_pipeline mul_mat_pipeline[30][2];
|
||||
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized
|
||||
webgpu_pipeline get_rows_pipeline[30];
|
||||
@@ -321,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context {
|
||||
|
||||
/* WebGPU object initializations */
|
||||
|
||||
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
|
||||
// the corresponding values provided in `repls`.
|
||||
static std::string ggml_webgpu_process_shader_repls(const char * src,
|
||||
const std::map<std::string, std::string> & repls) {
|
||||
if (!src) {
|
||||
return std::string();
|
||||
}
|
||||
std::string s = src;
|
||||
for (const auto & kv : repls) {
|
||||
std::string token = "{{" + kv.first + "}}";
|
||||
size_t pos = 0;
|
||||
while ((pos = s.find(token, pos)) != std::string::npos) {
|
||||
s.replace(pos, token.length(), kv.second);
|
||||
pos += kv.second.length();
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
static void ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||
webgpu_pipeline & pipeline,
|
||||
const char * shader_code,
|
||||
@@ -346,6 +399,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
|
||||
pipeline = { device.CreateComputePipeline(&pipeline_desc), label };
|
||||
}
|
||||
|
||||
static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device,
|
||||
const char * shader_code,
|
||||
const char * label,
|
||||
const std::vector<wgpu::ConstantEntry> & constants = {}) {
|
||||
wgpu::ShaderSourceWGSL shader_source;
|
||||
shader_source.code = shader_code;
|
||||
|
||||
wgpu::ShaderModuleDescriptor shader_desc;
|
||||
shader_desc.nextInChain = &shader_source;
|
||||
|
||||
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
|
||||
|
||||
wgpu::ComputePipelineDescriptor pipeline_desc;
|
||||
pipeline_desc.label = label;
|
||||
pipeline_desc.compute.module = shader_module;
|
||||
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
|
||||
pipeline_desc.layout = nullptr; // nullptr means auto layout
|
||||
if (constants.size() > 0) {
|
||||
pipeline_desc.compute.constants = constants.data();
|
||||
pipeline_desc.compute.constantCount = constants.size();
|
||||
}
|
||||
return { device.CreateComputePipeline(&pipeline_desc), label };
|
||||
}
|
||||
|
||||
static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
wgpu::Buffer & buffer,
|
||||
size_t size,
|
||||
@@ -512,6 +589,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
|
||||
std::vector<uint32_t> params,
|
||||
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
||||
uint32_t wg_x,
|
||||
uint32_t wg_y = 1,
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
|
||||
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
|
||||
|
||||
@@ -557,7 +635,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
|
||||
#endif
|
||||
pass.SetPipeline(pipeline.pipeline);
|
||||
pass.SetBindGroup(0, bind_group);
|
||||
pass.DispatchWorkgroups(wg_x, 1, 1);
|
||||
pass.DispatchWorkgroups(wg_x, wg_y, 1);
|
||||
pass.End();
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
@@ -779,7 +857,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
|
||||
uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
||||
@@ -835,8 +913,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) dst->ne[1], // number of rows in result (M)
|
||||
(uint32_t) dst->ne[0], // number of columns in result (N)
|
||||
(uint32_t) dst->ne[0], // number of rows in result (M, transposed)
|
||||
(uint32_t) dst->ne[1], // number of columns in result (N)
|
||||
(uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
|
||||
@@ -865,9 +943,67 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type];
|
||||
|
||||
uint32_t wg_x =
|
||||
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
|
||||
return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
|
||||
uint32_t wg_y = 1;
|
||||
|
||||
bool use_fast = false;
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
use_fast = (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (use_fast) {
|
||||
int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
|
||||
if (dst->ne[1] == 1) {
|
||||
// We don't support vectorized mul_mat_vec for quantized types
|
||||
vectorized = vectorized && (src0->type < 2);
|
||||
pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups =
|
||||
(dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
|
||||
wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) /
|
||||
ctx->limits.maxComputeWorkgroupsPerDimension;
|
||||
} else {
|
||||
pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
|
||||
uint32_t wg_m;
|
||||
uint32_t wg_n;
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
// The total number of subgroups/workgroups needed per matrix.
|
||||
uint32_t wg_m_sg_tile =
|
||||
WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
|
||||
wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile;
|
||||
uint32_t wg_n_sg_tile =
|
||||
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
|
||||
wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile;
|
||||
} else {
|
||||
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s;
|
||||
wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s;
|
||||
}
|
||||
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
}
|
||||
}
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
@@ -1583,12 +1719,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
||||
wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
||||
@@ -1627,6 +1757,136 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
|
||||
|
||||
if (webgpu_ctx->supports_subgroup_matrix) {
|
||||
std::map<std::string, std::string> sg_matrix_repls;
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
|
||||
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
|
||||
|
||||
std::string proc_mul_mat_subgroup_matrix_f32_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f16 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f32_f32_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f16_f32_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f16_f16_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_q4_0_f32_vec");
|
||||
} else {
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
|
||||
mul_mat_reg_tile_constants[0].key = "TILE_K";
|
||||
mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K;
|
||||
mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M";
|
||||
mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N";
|
||||
mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
|
||||
std::map<std::string, std::string> reg_repls;
|
||||
reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
|
||||
reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
|
||||
|
||||
// Process each reg-tile shader with tile replacements.
|
||||
// Keep the processed strings in-scope so .c_str() remains valid.
|
||||
std::string proc_mul_mat_reg_tile_f32_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f32_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f16 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f16_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_q4_0_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_q4_0_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(),
|
||||
"mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(),
|
||||
"mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(),
|
||||
"mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(),
|
||||
"mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(),
|
||||
"mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants);
|
||||
}
|
||||
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
|
||||
mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
|
||||
mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
mul_mat_vec_constants[1].key = "TILE_K";
|
||||
mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
|
||||
mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
|
||||
mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
|
||||
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
|
||||
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
@@ -2124,7 +2384,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
|
||||
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
||||
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
// TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
|
||||
const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
|
||||
wgpu::DawnTogglesDescriptor adapterTogglesDesc;
|
||||
adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
|
||||
adapterTogglesDesc.enabledToggleCount = 2;
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
options.nextInChain = &adapterTogglesDesc;
|
||||
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
@@ -2140,12 +2406,46 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
ctx->adapter.GetLimits(&ctx->limits);
|
||||
ctx->max_wg_size_x = 288; // default value
|
||||
|
||||
wgpu::AdapterInfo info{};
|
||||
wgpu::AdapterInfo info{};
|
||||
wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
|
||||
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
info.nextInChain = &subgroup_matrix_configs;
|
||||
}
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
||||
wgpu::SupportedFeatures features;
|
||||
ctx->adapter.GetFeatures(&features);
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
|
||||
// Only support square f16 matrices of size 8 or 16 for now
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||
ctx->subgroup_matrix_config = config;
|
||||
valid_subgroup_matrix_config = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
required_features.push_back(wgpu::FeatureName::Subgroups);
|
||||
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
||||
#endif
|
||||
|
||||
@@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
||||
except ValueError:
|
||||
decls_map = {}
|
||||
|
||||
with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
|
||||
common_decls = f.read()
|
||||
decls_map.update(parse_decls(common_decls))
|
||||
for fname in sorted(os.listdir(input_dir)):
|
||||
if fname.endswith(".tmpl"):
|
||||
tmpl_path = os.path.join(input_dir, fname)
|
||||
with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
|
||||
decls = f_tmpl.read()
|
||||
decls_map.update(parse_decls(decls))
|
||||
|
||||
shader_template = extract_block(text, "SHADER")
|
||||
for variant in variants:
|
||||
|
||||
@@ -864,8 +864,8 @@ struct MulMatParams {
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // N rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed)
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
@@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
|
||||
let dst2_rem = dst3_rem % dst2_stride;
|
||||
|
||||
let row = dst2_rem / params.n; // output row
|
||||
let col = dst2_rem % params.n; // output column
|
||||
let row = dst2_rem / params.m; // output row
|
||||
let col = dst2_rem % params.m; // output column
|
||||
|
||||
let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
|
||||
@@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
|
||||
sum += multiply_add(src0_idx_base, src1_idx_base, i);
|
||||
}
|
||||
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
|
||||
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
||||
97
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
Normal file
97
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
Normal file
@@ -0,0 +1,97 @@
|
||||
#decl(SHMEM_VEC)
|
||||
fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
shmem[idx] = val.x;
|
||||
shmem[idx + 1] = val.y;
|
||||
shmem[idx + 2] = val.z;
|
||||
shmem[idx + 3] = val.w;
|
||||
}
|
||||
#enddecl(SHMEM_VEC)
|
||||
|
||||
#decl(SHMEM_SCALAR)
|
||||
fn store_shmem(val: f16, idx: u32) {
|
||||
shmem[idx] = val;
|
||||
}
|
||||
#enddecl(SHMEM_SCALAR)
|
||||
|
||||
#decl(INIT_SRC0_SHMEM_FLOAT)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let src0_val = select( // taking a slight performance hit to avoid oob
|
||||
{{SRC0_TYPE}}(0.0),
|
||||
src0[src0_idx/{{VEC_SIZE}}],
|
||||
global_m < params.m && global_k < params.k);
|
||||
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(INIT_SRC0_SHMEM_FLOAT)
|
||||
|
||||
#decl(INIT_SRC1_SHMEM)
|
||||
|
||||
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
let tile_n = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_n = offset_n + tile_n;
|
||||
let global_k = k_outer + tile_k;
|
||||
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
|
||||
let src1_val = select(
|
||||
{{SRC1_TYPE}}(0.0),
|
||||
src1[src1_idx/{{VEC_SIZE}}],
|
||||
global_n < params.n && global_k < params.k);
|
||||
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(INIT_SRC1_SHMEM)
|
||||
|
||||
#decl(INIT_SRC0_SHMEM_Q4_0)
|
||||
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
shmem[shmem_idx + j * 2 + k] = q_lo;
|
||||
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(INIT_SRC0_SHMEM_Q4_0)
|
||||
247
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
Normal file
247
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
Normal file
@@ -0,0 +1,247 @@
|
||||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f16>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(VEC)
|
||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
|
||||
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
|
||||
}
|
||||
#enddecl(VEC)
|
||||
|
||||
#decl(SCALAR)
|
||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
|
||||
return f32(acc[tm][tn]);
|
||||
}
|
||||
#enddecl(SCALAR)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
enable f16;
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
DECLS
|
||||
|
||||
fn get_local_n(thread_id: u32) -> u32 {
|
||||
return thread_id / WORKGROUP_SIZE_M;
|
||||
}
|
||||
fn get_local_m(thread_id: u32) -> u32 {
|
||||
return thread_id % WORKGROUP_SIZE_M;
|
||||
}
|
||||
|
||||
// TILE_M must be multiple of 4 for vec4 loads
|
||||
const TILE_M = {{WEBGPU_TILE_M}}u;
|
||||
const TILE_N = {{WEBGPU_TILE_N}}u;
|
||||
|
||||
override WORKGROUP_SIZE_M: u32;
|
||||
override WORKGROUP_SIZE_N: u32;
|
||||
override TILE_K: u32;
|
||||
|
||||
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let local_m = get_local_m(thread_id);
|
||||
let local_n = get_local_n(thread_id);
|
||||
|
||||
let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
|
||||
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
|
||||
let wg_per_matrix = wg_m_count * wg_n_count;
|
||||
|
||||
let batch_idx = wg_id.x / wg_per_matrix;
|
||||
|
||||
let wg_in_batch = wg_id.x % wg_per_matrix;
|
||||
let wg_m = wg_in_batch % wg_m_count;
|
||||
let wg_n = wg_in_batch / wg_m_count;
|
||||
|
||||
let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
|
||||
let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
|
||||
let src03_idx = dst3_idx / params.broadcast3;
|
||||
let src13_idx = dst3_idx;
|
||||
let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
|
||||
let src02_idx = dst2_idx / params.broadcast2;
|
||||
let src12_idx = dst2_idx;
|
||||
|
||||
let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
|
||||
let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
|
||||
let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
|
||||
let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var acc: array<array<f16, TILE_N>, TILE_M>;
|
||||
|
||||
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
|
||||
|
||||
// see mul_mat_decls.tmpl
|
||||
init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
|
||||
init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
let k_end = min(TILE_K, params.k - k_outer);
|
||||
|
||||
for (var k_inner = 0u; k_inner < k_end; k_inner++) {
|
||||
var src0_tile: array<f16, TILE_M>;
|
||||
for (var tm = 0u; tm < TILE_M; tm++) {
|
||||
let src0_m = local_m * TILE_M + tm;
|
||||
let src0_idx = k_inner + src0_m * TILE_K;
|
||||
src0_tile[tm] = shmem[src0_idx];
|
||||
}
|
||||
for (var tn = 0u; tn < TILE_N; tn++) {
|
||||
let src1_n = local_n * TILE_N + tn;
|
||||
let src1_idx = src1_n * TILE_K + k_inner;
|
||||
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
|
||||
for (var tm = 0u; tm < TILE_M; tm++) {
|
||||
acc[tm][tn] += src0_tile[tm] * src1_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
|
||||
|
||||
for (var tn = 0u; tn < TILE_N; tn++) {
|
||||
let global_col = output_col_base + tn;
|
||||
if (global_col < params.n) {
|
||||
for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
|
||||
let global_row = output_row_base + tm;
|
||||
if (global_row < params.m) {
|
||||
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
|
||||
dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
@@ -0,0 +1,302 @@
|
||||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f16>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(VEC)
|
||||
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||
dst[dst_idx] = vec4<f32>(
|
||||
f32(shmem[shmem_idx]),
|
||||
f32(shmem[shmem_idx + 1]),
|
||||
f32(shmem[shmem_idx + 2]),
|
||||
f32(shmem[shmem_idx + 3])
|
||||
);
|
||||
}
|
||||
#enddecl(VEC)
|
||||
|
||||
#decl(SCALAR)
|
||||
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||
dst[dst_idx] = f32(shmem[shmem_idx]);
|
||||
}
|
||||
#enddecl(SCALAR)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
DECLS
|
||||
|
||||
// Note: These are string interpolated at build time, cannot use override constants due to limitations in
|
||||
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
|
||||
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
|
||||
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
|
||||
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
|
||||
// runtime subgroup size is smaller.
|
||||
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
|
||||
|
||||
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
|
||||
|
||||
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
|
||||
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
|
||||
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
|
||||
|
||||
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
|
||||
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
|
||||
|
||||
const TILE_K = {{WEBGPU_TILE_K}}u;
|
||||
|
||||
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
|
||||
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
// We reuse shmem for accumulation matrices
|
||||
const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
|
||||
|
||||
var<workgroup> shmem: array<f16, SHMEM_SIZE>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let subgroup_m = subgroup_id % SUBGROUP_M;
|
||||
let subgroup_n = subgroup_id / SUBGROUP_M;
|
||||
|
||||
let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
|
||||
let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
|
||||
let wg_per_matrix = wg_m_count * wg_n_count;
|
||||
|
||||
let batch_idx = wg_id.x / wg_per_matrix;
|
||||
|
||||
let wg_in_batch = wg_id.x % wg_per_matrix;
|
||||
let wg_m = wg_in_batch % wg_m_count;
|
||||
let wg_n = wg_in_batch / wg_m_count;
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
|
||||
let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
|
||||
let src03_idx = dst3_idx / params.broadcast3;
|
||||
let src13_idx = dst3_idx;
|
||||
let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
|
||||
let src02_idx = dst2_idx / params.broadcast2;
|
||||
let src12_idx = dst2_idx;
|
||||
|
||||
let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
|
||||
let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
|
||||
let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;
|
||||
|
||||
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
|
||||
|
||||
// see mul_mat_decls.tmpl
|
||||
init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
|
||||
init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (subgroup_id < EXPECTED_SUBGROUPS) {
|
||||
|
||||
for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
|
||||
|
||||
let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
|
||||
var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
|
||||
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
|
||||
src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
|
||||
&shmem,
|
||||
src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
|
||||
false,
|
||||
TILE_K
|
||||
);
|
||||
}
|
||||
|
||||
let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
|
||||
for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
|
||||
let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
|
||||
&shmem,
|
||||
src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
|
||||
true,
|
||||
TILE_K
|
||||
);
|
||||
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
|
||||
acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
|
||||
|
||||
// Stage the subgroup matrix tiles into shared memory
|
||||
// This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
|
||||
let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
|
||||
let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
|
||||
if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
|
||||
for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
|
||||
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
|
||||
let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
|
||||
let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
|
||||
let out_base = local_row * WG_TILE_STRIDE + local_col;
|
||||
subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// Cooperative write: iterate over the entire workgroup tile
|
||||
let tile_rows = WG_N_SG_TILE_SIZE;
|
||||
let tile_cols = WG_M_SG_TILE_SIZE;
|
||||
let total_tile_elems = tile_rows * tile_cols;
|
||||
let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
let local_row = idx % WG_TILE_STRIDE;
|
||||
let local_col = idx / WG_TILE_STRIDE;
|
||||
|
||||
let global_row = tile_dst_row_base + local_row;
|
||||
let global_col = tile_dst_col_base + local_col;
|
||||
|
||||
if (global_col < params.n && global_row < params.m) {
|
||||
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
|
||||
store_dst(idx, dst_idx/{{VEC_SIZE}});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
267
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
Normal file
267
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
Normal file
@@ -0,0 +1,267 @@
|
||||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f16>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(VEC)
|
||||
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
|
||||
return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
|
||||
}
|
||||
|
||||
fn store_val(group_base: u32) -> vec4<f32> {
|
||||
return vec4<f32>(partial_sums[group_base],
|
||||
partial_sums[group_base + THREADS_PER_OUTPUT],
|
||||
partial_sums[group_base + THREADS_PER_OUTPUT * 2],
|
||||
partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
|
||||
}
|
||||
#enddecl(VEC)
|
||||
|
||||
#decl(SCALAR)
|
||||
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
|
||||
return f32(src0_val) * f32(src1_val);
|
||||
}
|
||||
|
||||
fn store_val(group_base: u32) -> f32 {
|
||||
return partial_sums[group_base];
|
||||
}
|
||||
#enddecl(SCALAR)
|
||||
|
||||
#decl(MUL_ACC_FLOAT)
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
|
||||
let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
|
||||
let b = shared_vector[i / {{VEC_SIZE}}];
|
||||
local_sum += inner_dot(a, b);
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
|
||||
#enddecl(MUL_ACC_FLOAT)
|
||||
|
||||
#decl(MUL_ACC_Q4_0)
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1 + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
|
||||
#enddecl(MUL_ACC_Q4_0)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
enable f16;
|
||||
|
||||
DECLS
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
m: u32,
|
||||
n: u32,
|
||||
k: u32,
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
stride_03: u32,
|
||||
stride_13: u32,
|
||||
bs02: u32,
|
||||
bs03: u32,
|
||||
broadcast2: u32,
|
||||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
override WORKGROUP_SIZE: u32;
|
||||
override TILE_K: u32;
|
||||
override OUTPUTS_PER_WG: u32;
|
||||
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
|
||||
|
||||
// Shared memory for collaborative loading and reduction
|
||||
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
|
||||
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
|
||||
|
||||
@compute @workgroup_size(WORKGROUP_SIZE)
|
||||
fn main(
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let thread_id = local_id.x;
|
||||
|
||||
// Handle batch dimensions
|
||||
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
|
||||
let batch_idx = wg_linear / output_groups;
|
||||
if (batch_idx >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Which of the outputs does this thread belong to?
|
||||
let thread_group = thread_id / THREADS_PER_OUTPUT;
|
||||
let thread_in_group = thread_id % THREADS_PER_OUTPUT;
|
||||
|
||||
// Each workgroup computes OUTPUTS_PER_WG consecutive outputs
|
||||
let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
|
||||
|
||||
let dst2_stride = params.m * params.n;
|
||||
let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
|
||||
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
|
||||
let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
|
||||
let src03_idx = dst3_idx / params.broadcast3;
|
||||
let src13_idx = dst3_idx;
|
||||
let src02_idx = dst2_idx / params.broadcast2;
|
||||
let src12_idx = dst2_idx;
|
||||
|
||||
let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
|
||||
|
||||
var local_sum = 0.0;
|
||||
|
||||
// Each thread processes multiple K elements and accumulates
|
||||
for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
|
||||
let tile_size = min(TILE_K, params.k - k_tile);
|
||||
|
||||
// Cooperatively load vector tile into shared memory (all threads)
|
||||
for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (output_row < params.m) {
|
||||
local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// Store partial sums and reduce within each partition
|
||||
partial_sums[thread_id] = local_sum;
|
||||
workgroupBarrier();
|
||||
let group_base = thread_group * THREADS_PER_OUTPUT;
|
||||
let thread_base = group_base + thread_in_group;
|
||||
var offset = THREADS_PER_OUTPUT / 2;
|
||||
while (offset > 0) {
|
||||
if (thread_in_group < offset) {
|
||||
partial_sums[thread_base] += partial_sums[thread_base + offset];
|
||||
}
|
||||
offset = offset / 2;
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// Store back to global memory
|
||||
if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
|
||||
dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
|
||||
}
|
||||
}
|
||||
#end(SHADER)
|
||||
@@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
|
||||
struct test_rms_norm_mul_rope : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
const bool multi_add; // test a sequence of adds feeding into rms_norm
|
||||
const bool set_rows;
|
||||
int mode;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "RMS_NORM_MUL_ROPE";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
|
||||
}
|
||||
|
||||
test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
|
||||
bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
|
||||
: ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
|
||||
ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
|
||||
|
||||
if (multi_add) {
|
||||
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
|
||||
}
|
||||
|
||||
a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
|
||||
|
||||
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
|
||||
|
||||
ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);
|
||||
|
||||
ggml_tensor * out;
|
||||
|
||||
if (set_rows) {
|
||||
ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
|
||||
|
||||
ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
|
||||
ggml_set_name(dst, "dst");
|
||||
|
||||
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
|
||||
ggml_set_name(row_idxs, "row_idxs");
|
||||
|
||||
out = ggml_set_rows(ctx, dst, view, row_idxs);
|
||||
ggml_set_name(out, "out");
|
||||
} else {
|
||||
out = rope;
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
init_set_rows_row_ids(t, ne[2]);
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ARGMAX
|
||||
struct test_argmax : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -3484,6 +3557,27 @@ struct test_mul_mat : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
// ids
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<int32_t> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i % n_mats;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GGML_OP_MUL_MAT_ID
|
||||
struct test_mul_mat_id : public test_case {
|
||||
const ggml_type type_a;
|
||||
@@ -3494,10 +3588,9 @@ struct test_mul_mat_id : public test_case {
|
||||
const int64_t m;
|
||||
const int64_t n;
|
||||
const int64_t k;
|
||||
const uint32_t o; // number of outputs
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o);
|
||||
return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
@@ -3511,9 +3604,69 @@ struct test_mul_mat_id : public test_case {
|
||||
|
||||
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
|
||||
int n_mats = 8, int n_used = 2, bool b = false,
|
||||
int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1)
|
||||
int64_t m = 32, int64_t n = 32, int64_t k = 32)
|
||||
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
|
||||
m(m), n(n), k(k), o(o) {
|
||||
m(m), n(n), k(k) {
|
||||
GGML_ASSERT(n_used <= n_mats);
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
|
||||
ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
|
||||
ggml_set_name(as, "as");
|
||||
|
||||
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
|
||||
ggml_set_name(ids, "ids");
|
||||
if (n_used != n_mats) {
|
||||
ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
|
||||
ggml_set_name(ids, "view_of_ids");
|
||||
}
|
||||
|
||||
ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
init_mul_mat_id_tensors(ctx, n_mats);
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
|
||||
struct test_mul_mat_id_fusion : public test_case {
|
||||
const ggml_type type_a;
|
||||
const ggml_type type_b;
|
||||
const int n_mats;
|
||||
const int n_used;
|
||||
const bool b; // broadcast b matrix
|
||||
const int64_t m;
|
||||
const int64_t n;
|
||||
const int64_t k;
|
||||
const uint32_t o; // number of outputs
|
||||
const bool mul;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
uint64_t op_flops(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return 2 * m * k * n * n_used;
|
||||
}
|
||||
|
||||
test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
|
||||
int n_mats = 8, int n_used = 2, bool b = false,
|
||||
int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)
|
||||
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
|
||||
m(m), n(n), k(k), o(o), mul(mul) {
|
||||
GGML_ASSERT(n_used <= n_mats);
|
||||
}
|
||||
|
||||
@@ -3542,35 +3695,25 @@ struct test_mul_mat_id : public test_case {
|
||||
out = ggml_add(ctx, out, out2);
|
||||
}
|
||||
|
||||
if (mul) {
|
||||
std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };
|
||||
ne[0] = 1;
|
||||
ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data());
|
||||
out = ggml_mul(ctx, out, m);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
// ids
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<int32_t> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i % n_mats;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
init_mul_mat_id_tensors(ctx, n_mats);
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return o > 1; }
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return ggml_op_name(GGML_OP_MUL_MAT_ID);
|
||||
return "MUL_MAT_ID_FUSION";
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4809,60 +4952,6 @@ struct test_topk_moe: public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
struct test_moe_expert_reduce : public test_case {
|
||||
const int64_t n_embd;
|
||||
const int64_t n_tokens;
|
||||
const int64_t n_expert_used;
|
||||
|
||||
test_moe_expert_reduce(int64_t n_embd = 64, int64_t n_tokens = 5, int64_t n_expert_used = 4)
|
||||
: n_embd(n_embd), n_tokens(n_tokens), n_expert_used(n_expert_used) {
|
||||
GGML_ASSERT(n_expert_used > 1);
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(n_embd, n_tokens, n_expert_used);
|
||||
}
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "MOE_EXPERT_REDUCE";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
|
||||
ggml_set_name(experts, "experts");
|
||||
|
||||
ggml_tensor * weights = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, n_expert_used, n_tokens);
|
||||
ggml_set_name(weights, "weights");
|
||||
|
||||
ggml_tensor * weighted = ggml_mul(ctx, experts, weights);
|
||||
ggml_set_name(weighted, "weighted_experts");
|
||||
|
||||
std::vector<ggml_tensor *> expert_views(n_expert_used);
|
||||
for (int64_t i = 0; i < n_expert_used; ++i) {
|
||||
expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]);
|
||||
|
||||
std::string name = "expert_view_" + std::to_string(i);
|
||||
ggml_set_name(expert_views[i], name.c_str());
|
||||
ggml_build_forward_expand(gf, expert_views[i]);
|
||||
}
|
||||
|
||||
ggml_tensor * moe_out = expert_views[0];
|
||||
for (int64_t i = 1; i < n_expert_used; ++i) {
|
||||
moe_out = ggml_add(ctx, moe_out, expert_views[i]);
|
||||
|
||||
std::string name = "expert_add_" + std::to_string(i - 1);
|
||||
ggml_set_name(moe_out, name.c_str());
|
||||
}
|
||||
|
||||
ggml_set_name(moe_out, "moe_out");
|
||||
|
||||
return moe_out;
|
||||
}
|
||||
};
|
||||
|
||||
struct test_mul_mat_vec_fusion : public test_case {
|
||||
const ggml_type type;
|
||||
const ggml_glu_op glu_op;
|
||||
@@ -4911,8 +5000,10 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
std::array<int64_t, 4> ne = {k, m, 1, 1};
|
||||
std::array<int64_t, 4> ne0 = {k, n, 1, 1};
|
||||
const int channels = 4;
|
||||
const int samples = 2;
|
||||
std::array<int64_t, 4> ne = { k, m, channels, samples };
|
||||
std::array<int64_t, 4> ne0 = { k, n, channels, samples };
|
||||
|
||||
ggml_tensor * cur = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
|
||||
@@ -4920,14 +5011,14 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
|
||||
ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
|
||||
if (with_bias) {
|
||||
std::array<int64_t, 4> bias_ne = {ffn_up->ne[0], 1, 1, 1};
|
||||
std::array<int64_t, 4> bias_ne = { ffn_up->ne[0], 1, channels, samples };
|
||||
ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||
ffn_up = ggml_add(ctx, ffn_up, up_bias);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
|
||||
if (with_bias && with_gate) {
|
||||
std::array<int64_t, 4> bias_ne = {ffn_gate->ne[0], 1, 1, 1};
|
||||
std::array<int64_t, 4> bias_ne = { ffn_gate->ne[0], 1, channels, samples };
|
||||
ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||
ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
|
||||
}
|
||||
@@ -4971,24 +5062,7 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
} else {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
// ids
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<int32_t> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i % n_mats;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
init_mul_mat_id_tensors(ctx, n_mats);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6648,6 +6722,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
|
||||
test_cases.emplace_back(new test_cont());
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
|
||||
@@ -6750,6 +6825,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto multi_add : {false, true}) {
|
||||
for (auto set_rows : {false, true}) {
|
||||
for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
||||
for (int64_t d_conv : {3, 4}) {
|
||||
@@ -6941,7 +7032,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
|
||||
|
||||
// gpt-oss issue with Vulkan mmq_id
|
||||
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
|
||||
@@ -6978,6 +7069,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
for (int bs : {1, 4, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
// test with mul after (ffn_moe_weighted)
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
for (int n : {1, 16}) {
|
||||
@@ -7323,10 +7423,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
|
||||
test_cases.emplace_back(new test_moe_expert_reduce(1024, 5, 4));
|
||||
test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 6));
|
||||
test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 7));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
@@ -7438,7 +7534,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7446,7 +7542,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7456,7 +7552,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
for (int bs : {1, 4, 8, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
|
||||
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,7 +512,7 @@ These words will not be included in the completion, so make sure to add them to
|
||||
|
||||
`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
|
||||
|
||||
`return_progress`: Include prompt processing progress in `stream` mode. The progress will be contained inside `prompt_progress` with 3 values: `total`, `cache` and `processed`. The overall progress is `processed/total`, while the actual timed progress is `(processed-cache)/(total-cache)`. Default: `false`
|
||||
`return_progress`: Include prompt processing progress in `stream` mode. The progress will be contained inside `prompt_progress` with 4 values: `total`, `cache`, `processed`, and `time_ms`. The overall progress is `processed/total`, while the actual timed progress is `(processed-cache)/(total-cache)`. The `time_ms` field contains the elapsed time in milliseconds since prompt processing started. Default: `false`
|
||||
|
||||
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
||||
|
||||
|
||||
Binary file not shown.
@@ -3078,7 +3078,7 @@ struct server_context {
|
||||
res->progress.total = slot.task->n_tokens();
|
||||
res->progress.cache = slot.n_prompt_tokens_cache;
|
||||
res->progress.processed = slot.prompt.tokens.size();
|
||||
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000);
|
||||
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000;
|
||||
} else {
|
||||
res->content = tkn.text_to_send;
|
||||
res->tokens = { tkn.tok };
|
||||
|
||||
@@ -44,12 +44,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
if (isCtrlOrCmd && event.shiftKey && event.key === 'o') {
|
||||
if (isCtrlOrCmd && event.shiftKey && event.key === 'O') {
|
||||
event.preventDefault();
|
||||
goto('?new_chat=true#/');
|
||||
}
|
||||
|
||||
if (event.shiftKey && isCtrlOrCmd && event.key === 'e') {
|
||||
if (event.shiftKey && isCtrlOrCmd && event.key === 'E') {
|
||||
event.preventDefault();
|
||||
|
||||
if (chatSidebar?.editActiveConversation) {
|
||||
|
||||
Reference in New Issue
Block a user