mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-03 07:34:07 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5f8812f7c | ||
|
|
8ece3836b4 | ||
|
|
046d5fd44e | ||
|
|
480160d472 | ||
|
|
15bff84bf5 |
44
.github/workflows/build.yml
vendored
44
.github/workflows/build.yml
vendored
@@ -152,13 +152,13 @@ jobs:
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
|
||||
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -532,13 +532,13 @@ jobs:
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
|
||||
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -1704,6 +1704,34 @@ jobs:
|
||||
run: |
|
||||
GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-webgpu:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
run: |
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
|
||||
bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-vulkan:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
|
||||
15
ci/run.sh
15
ci/run.sh
@@ -105,7 +105,20 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_WEBGPU} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1 -DGGML_METAL=OFF -DGGML_BLAS=OFF"
|
||||
|
||||
if [ ! -z "${GG_BUILD_WEBGPU_DAWN_PREFIX}" ]; then
|
||||
if [ -z "${CMAKE_PREFIX_PATH}" ]; then
|
||||
export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}"
|
||||
else
|
||||
export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}:${CMAKE_PREFIX_PATH}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For some systems, Dawn_DIR needs to be set explicitly, e.g., the lib64 path
|
||||
if [ ! -z "${GG_BUILD_WEBGPU_DAWN_DIR}" ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DDawn_DIR=${GG_BUILD_WEBGPU_DAWN_DIR}"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_MUSA} ]; then
|
||||
|
||||
167
common/arg.cpp
167
common/arg.cpp
@@ -6,6 +6,7 @@
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "download.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
#if defined(_WIN32)
|
||||
@@ -268,6 +269,46 @@ static void parse_tensor_buffer_overrides(const std::string & value, std::vector
|
||||
}
|
||||
}
|
||||
|
||||
// Turn an arbitrary name into a flat file name: every path separator
// (backslash or forward slash) becomes an underscore, all other characters
// are kept as-is.
static std::string clean_file_name(const std::string & fname) {
    std::string flattened = fname;
    std::replace_if(
        flattened.begin(), flattened.end(),
        [](char ch) { return ch == '\\' || ch == '/'; },
        '_');
    return flattened;
}
|
||||
|
||||
// Fetch (or reuse from cache) the optional `preset.ini` stored at the root of the
// Hugging Face repo named by params.model.hf_repo and apply its default section.
//
// params: must have a non-empty model.hf_repo; modified in place when a preset is applied
// ex:     example for which the preset options are parsed
//
// Returns true when a preset file was obtained (status in [200, 400)), false otherwise.
// A missing preset is not an error. Throws std::runtime_error when the file exists but
// has no default section; exceptions raised while loading the INI file propagate too.
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
    GGML_ASSERT(!params.model.hf_repo.empty());

    // the repo may carry a ":<tag>" suffix (cf. the hf_repo_with_tag parameter of
    // common_get_hf_file); the tag is not part of the repo path, so it must not end up
    // in the URL or in the cache file name. params.model.hf_repo itself is left untouched.
    // substr(0, npos) yields the whole string when there is no tag.
    const std::string hf_repo = params.model.hf_repo.substr(0, params.model.hf_repo.find(':'));

    const bool offline = params.offline;
    const std::string model_endpoint = get_model_endpoint();
    const auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";

    // prepare local path for caching
    const auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
    const auto preset_path  = fs_get_cache_file(preset_fname);
    const int  status       = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
    const bool has_preset   = status >= 200 && status < 400;

    // remote preset is optional, so we don't error out if not found
    if (!has_preset) {
        LOG_INF("%s", "no remote preset found, skipping\n");
        return false;
    }

    LOG_INF("applying remote preset from %s\n", preset_url.c_str());
    common_preset_context ctx(ex, /* only_remote_allowed */ true);
    common_preset global; // unused for now
    auto remote_presets = ctx.load_from_ini(preset_path, global);

    const auto it = remote_presets.find(COMMON_PRESET_DEFAULT_NAME);
    if (it == remote_presets.end()) {
        throw std::runtime_error("Remote preset.ini does not contain [" + std::string(COMMON_PRESET_DEFAULT_NAME) + "] section");
    }

    common_preset & preset = it->second;
    LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
    preset.apply_to_params(params);

    return true;
}
|
||||
|
||||
struct handle_model_result {
|
||||
bool found_mmproj = false;
|
||||
common_params_model mmproj;
|
||||
@@ -309,9 +350,7 @@ static handle_model_result common_params_handle_model(
|
||||
// make sure model path is present (for caching purposes)
|
||||
if (model.path.empty()) {
|
||||
// this is to avoid different repo having same file name, or same file name in different subdirs
|
||||
std::string filename = model.hf_repo + "_" + model.hf_file;
|
||||
// to make sure we don't have any slashes in the filename
|
||||
string_replace_all(filename, "/", "_");
|
||||
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
|
||||
model.path = fs_get_cache_file(filename);
|
||||
}
|
||||
|
||||
@@ -425,61 +464,87 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
}
|
||||
};
|
||||
|
||||
std::set<std::string> seen_args;
|
||||
auto parse_cli_args = [&]() {
|
||||
std::set<std::string> seen_args;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const std::string arg_prefix = "--";
|
||||
for (int i = 1; i < argc; i++) {
|
||||
const std::string arg_prefix = "--";
|
||||
|
||||
std::string arg = argv[i];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
if (arg_to_options.find(arg) == arg_to_options.end()) {
|
||||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
auto & tmp = arg_to_options[arg];
|
||||
auto opt = *tmp.first;
|
||||
bool is_positive = tmp.second;
|
||||
if (opt.has_value_from_env()) {
|
||||
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
|
||||
}
|
||||
try {
|
||||
if (opt.handler_void) {
|
||||
opt.handler_void(params);
|
||||
continue;
|
||||
std::string arg = argv[i];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
if (opt.handler_bool) {
|
||||
opt.handler_bool(params, is_positive);
|
||||
continue;
|
||||
if (arg_to_options.find(arg) == arg_to_options.end()) {
|
||||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
auto & tmp = arg_to_options[arg];
|
||||
auto opt = *tmp.first;
|
||||
bool is_positive = tmp.second;
|
||||
if (opt.has_value_from_env()) {
|
||||
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
|
||||
}
|
||||
try {
|
||||
if (opt.handler_void) {
|
||||
opt.handler_void(params);
|
||||
continue;
|
||||
}
|
||||
if (opt.handler_bool) {
|
||||
opt.handler_bool(params, is_positive);
|
||||
continue;
|
||||
}
|
||||
|
||||
// arg with single value
|
||||
check_arg(i);
|
||||
std::string val = argv[++i];
|
||||
if (opt.handler_int) {
|
||||
opt.handler_int(params, std::stoi(val));
|
||||
continue;
|
||||
}
|
||||
if (opt.handler_string) {
|
||||
opt.handler_string(params, val);
|
||||
continue;
|
||||
}
|
||||
// arg with single value
|
||||
check_arg(i);
|
||||
std::string val = argv[++i];
|
||||
if (opt.handler_int) {
|
||||
opt.handler_int(params, std::stoi(val));
|
||||
continue;
|
||||
}
|
||||
if (opt.handler_string) {
|
||||
opt.handler_string(params, val);
|
||||
continue;
|
||||
}
|
||||
|
||||
// arg with 2 values
|
||||
check_arg(i);
|
||||
std::string val2 = argv[++i];
|
||||
if (opt.handler_str_str) {
|
||||
opt.handler_str_str(params, val, val2);
|
||||
continue;
|
||||
// arg with 2 values
|
||||
check_arg(i);
|
||||
std::string val2 = argv[++i];
|
||||
if (opt.handler_str_str) {
|
||||
opt.handler_str_str(params, val, val2);
|
||||
continue;
|
||||
}
|
||||
} catch (std::exception & e) {
|
||||
throw std::invalid_argument(string_format(
|
||||
"error while handling argument \"%s\": %s\n\n"
|
||||
"usage:\n%s\n\nto show complete usage, run with -h",
|
||||
arg.c_str(), e.what(), opt.to_string().c_str()));
|
||||
}
|
||||
} catch (std::exception & e) {
|
||||
throw std::invalid_argument(string_format(
|
||||
"error while handling argument \"%s\": %s\n\n"
|
||||
"usage:\n%s\n\nto show complete usage, run with -h",
|
||||
arg.c_str(), e.what(), opt.to_string().c_str()));
|
||||
}
|
||||
};
|
||||
|
||||
// parse the first time to get -hf option (used for remote preset)
|
||||
parse_cli_args();
|
||||
|
||||
// maybe handle remote preset
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
std::string cli_hf_repo = params.model.hf_repo;
|
||||
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
|
||||
|
||||
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
|
||||
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
|
||||
std::string preset_hf_repo = params.model.hf_repo;
|
||||
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
|
||||
|
||||
if (has_preset) {
|
||||
// re-parse CLI args to override preset values
|
||||
parse_cli_args();
|
||||
}
|
||||
|
||||
// preserve hf_repo from preset if needed
|
||||
if (preset_has_hf_repo) {
|
||||
params.model.hf_repo = preset_hf_repo;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -157,6 +157,10 @@ static std::string read_etag(const std::string & path) {
|
||||
return none;
|
||||
}
|
||||
|
||||
// True for HTTP status codes in the half-open range [200, 400),
// false for everything else (including negative values).
static bool is_http_status_ok(int status) {
    const bool below_range = status < 200;
    const bool above_range = status >= 400;
    return !(below_range || above_range);
}
|
||||
|
||||
#ifdef LLAMA_USE_CURL
|
||||
|
||||
//
|
||||
@@ -306,12 +310,14 @@ static bool common_download_head(CURL * curl,
|
||||
}
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
static bool common_download_file_single_online(const std::string & url,
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
std::string etag;
|
||||
|
||||
@@ -371,7 +377,7 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,14 +386,14 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
if (std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::filesystem::exists(path)) {
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -414,23 +420,27 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code < 200 || http_code >= 400) {
|
||||
|
||||
int status = static_cast<int>(http_code);
|
||||
if (!is_http_status_ok(http_code)) {
|
||||
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
|
||||
return false;
|
||||
return status; // TODO: maybe only return on certain codes
|
||||
}
|
||||
|
||||
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return static_cast<int>(http_code);
|
||||
} else {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
break;
|
||||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
|
||||
@@ -625,7 +635,8 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
}
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
static bool common_download_file_single_online(const std::string & url,
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
@@ -659,8 +670,10 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return true;
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head->status; // cannot use cached file, return raw status code
|
||||
// TODO: maybe retry only on certain codes
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
@@ -692,12 +705,12 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
if (file_exists) {
|
||||
if (!should_download_from_scratch) {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return true;
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,7 +722,7 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -730,15 +743,16 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
break;
|
||||
|
||||
return head->status; // TODO: use actual GET status?
|
||||
}
|
||||
|
||||
return true;
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
||||
@@ -777,22 +791,22 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
|
||||
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
static bool common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers) {
|
||||
if (!offline) {
|
||||
return common_download_file_single_online(url, path, bearer_token, headers);
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(path)) {
|
||||
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
return -1;
|
||||
}
|
||||
|
||||
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
||||
return true;
|
||||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
|
||||
// download multiple files from remote URLs to local paths
|
||||
@@ -810,7 +824,8 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
|
||||
std::async(
|
||||
std::launch::async,
|
||||
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
|
||||
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
|
||||
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
|
||||
return is_http_status_ok(http_status);
|
||||
},
|
||||
item
|
||||
)
|
||||
@@ -837,7 +852,8 @@ bool common_download_model(const common_params_model & model,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
|
||||
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -975,7 +991,7 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
|
||||
} else if (res_code == 401) {
|
||||
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
|
||||
} else {
|
||||
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
|
||||
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
|
||||
}
|
||||
|
||||
// check response
|
||||
@@ -1094,7 +1110,8 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
std::string local_path = fs_get_cache_file(model_filename);
|
||||
|
||||
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
|
||||
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
|
||||
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
throw std::runtime_error("Failed to download Docker Model");
|
||||
}
|
||||
|
||||
@@ -1120,6 +1137,14 @@ std::string common_docker_resolve_model(const std::string &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
int common_download_file_single(const std::string &,
|
||||
const std::string &,
|
||||
const std::string &,
|
||||
bool,
|
||||
const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
|
||||
@@ -65,6 +65,14 @@ bool common_download_model(
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
|
||||
// download single file from url to local path
|
||||
// returns status code or -1 on error
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers = {});
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
// return local path to downloaded model file
|
||||
std::string common_docker_resolve_model(const std::string & docker);
|
||||
|
||||
@@ -16,6 +16,46 @@ static std::string rm_leading_dashes(const std::string & str) {
|
||||
return str.substr(pos);
|
||||
}
|
||||
|
||||
// only allow a subset of args for remote presets for security reasons
|
||||
// do not add more args unless absolutely necessary
|
||||
// args that output to files are strictly prohibited
|
||||
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
|
||||
static const std::set<std::string> allowed_options = {
|
||||
"model-url",
|
||||
"hf-repo",
|
||||
"hf-repo-draft",
|
||||
"hf-repo-v", // vocoder
|
||||
"hf-file-v", // vocoder
|
||||
"mmproj-url",
|
||||
"pooling",
|
||||
"jinja",
|
||||
"batch-size",
|
||||
"ubatch-size",
|
||||
"cache-reuse",
|
||||
// note: sampling params are automatically allowed by default
|
||||
// negated args will be added automatically
|
||||
};
|
||||
|
||||
std::set<std::string> allowed_keys;
|
||||
|
||||
for (const auto & it : key_to_opt) {
|
||||
const std::string & key = it.first;
|
||||
const common_arg & opt = it.second;
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
||||
allowed_keys.insert(key);
|
||||
// also add variant keys (args without leading dashes and env vars)
|
||||
for (const auto & arg : opt.get_args()) {
|
||||
allowed_keys.insert(rm_leading_dashes(arg));
|
||||
}
|
||||
for (const auto & env : opt.get_env()) {
|
||||
allowed_keys.insert(env);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allowed_keys;
|
||||
}
|
||||
|
||||
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
||||
std::vector<std::string> args;
|
||||
|
||||
@@ -121,6 +161,29 @@ void common_preset::merge(const common_preset & other) {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply every option stored in this preset to `params` by invoking the option's handler.
// The first handler that is set wins, checked in this order: string, int, bool,
// two-value (unsupported, throws std::runtime_error), void. An option without any
// handler aborts.
void common_preset::apply_to_params(common_params & params) const {
    for (const auto & [arg, raw] : options) {
        if (arg.handler_string) {
            arg.handler_string(params, raw);
            continue;
        }
        if (arg.handler_int) {
            arg.handler_int(params, std::stoi(raw));
            continue;
        }
        if (arg.handler_bool) {
            arg.handler_bool(params, common_arg_utils::is_truthy(raw));
            continue;
        }
        if (arg.handler_str_str) {
            // not supported yet
            throw std::runtime_error(string_format(
                "%s: option with two values is not supported yet",
                __func__
            ));
        }
        if (arg.handler_void) {
            arg.handler_void(params);
            continue;
        }
        GGML_ABORT("unknown handler type");
    }
}
|
||||
|
||||
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
|
||||
std::map<std::string, std::map<std::string, std::string>> parsed;
|
||||
|
||||
@@ -230,10 +293,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
|
||||
return value;
|
||||
}
|
||||
|
||||
common_preset_context::common_preset_context(llama_example ex)
|
||||
// Set up the option parser context for `ex`, register the preset-specific options and
// build the key -> option lookup. When only_remote_allowed is true, key filtering is
// switched on and the allowed keys are taken from the remote preset whitelist.
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
        : ctx_params(common_params_parser_init(default_params, ex)) {
    common_params_add_preset_options(ctx_params.options);
    key_to_opt = get_map_key_opt(ctx_params);

    if (!only_remote_allowed) {
        return; // no filtering requested
    }

    // the whitelist is derived from key_to_opt, so it must be computed after the map is built
    allowed_keys        = get_remote_preset_whitelist(key_to_opt);
    filter_allowed_keys = true;
}
|
||||
|
||||
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
|
||||
@@ -250,6 +319,12 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
|
||||
LOG_DBG("loading preset: %s\n", preset.name.c_str());
|
||||
for (const auto & [key, value] : section.second) {
|
||||
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
|
||||
if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
|
||||
throw std::runtime_error(string_format(
|
||||
"option '%s' is not allowed in remote presets",
|
||||
key.c_str()
|
||||
));
|
||||
}
|
||||
if (key_to_opt.find(key) != key_to_opt.end()) {
|
||||
const auto & opt = key_to_opt.at(key);
|
||||
if (is_bool_arg(opt)) {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
//
|
||||
// INI preset parser and writer
|
||||
@@ -40,6 +41,9 @@ struct common_preset {
|
||||
|
||||
// merge another preset into this one, overwriting existing options
|
||||
void merge(const common_preset & other);
|
||||
|
||||
// apply preset options to common_params
|
||||
void apply_to_params(common_params & params) const;
|
||||
};
|
||||
|
||||
// interface for multiple presets in one file
|
||||
@@ -50,7 +54,12 @@ struct common_preset_context {
|
||||
common_params default_params; // unused for now
|
||||
common_params_context ctx_params;
|
||||
std::map<std::string, common_arg> key_to_opt;
|
||||
common_preset_context(llama_example ex);
|
||||
|
||||
bool filter_allowed_keys = false;
|
||||
std::set<std::string> allowed_keys;
|
||||
|
||||
// if only_remote_allowed is true, only accept whitelisted keys
|
||||
common_preset_context(llama_example ex, bool only_remote_allowed = false);
|
||||
|
||||
// load presets from INI file
|
||||
common_presets load_from_ini(const std::string & path, common_preset & global) const;
|
||||
|
||||
60
docs/preset.md
Normal file
60
docs/preset.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# llama.cpp INI Presets
|
||||
|
||||
## Introduction
|
||||
|
||||
The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/llama.cpp/pull/17859), allows users to create reusable and shareable parameter configurations for llama.cpp.
|
||||
|
||||
### Using Presets with the Server
|
||||
|
||||
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
|
||||
|
||||
### Using a Remote Preset
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> This feature is currently only supported via the `-hf` option.
|
||||
|
||||
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
|
||||
|
||||
Example:
|
||||
|
||||
```ini
|
||||
hf-repo-draft = username/my-draft-model-GGUF
|
||||
temp = 0.5
|
||||
top-k = 20
|
||||
top-p = 0.95
|
||||
```
|
||||
|
||||
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
|
||||
|
||||
Example usage:
|
||||
|
||||
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
|
||||
|
||||
```sh
|
||||
llama-cli -hf username/my-model-with-preset
|
||||
|
||||
# This is equivalent to:
|
||||
llama-cli -hf username/my-model-with-preset \
|
||||
--hf-repo-draft username/my-draft-model-GGUF \
|
||||
--temp 0.5 \
|
||||
--top-k 20 \
|
||||
--top-p 0.95
|
||||
```
|
||||
|
||||
You can also override preset arguments by specifying them on the command line:
|
||||
|
||||
```sh
|
||||
# Force temp = 0.1, overriding the preset value
|
||||
llama-cli -hf username/my-model-with-preset --temp 0.1
|
||||
```
|
||||
|
||||
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
|
||||
|
||||
```ini
|
||||
hf-repo = user/my-model-main
|
||||
hf-repo-draft = user/my-model-draft
|
||||
temp = 0.8
|
||||
batch-size = 1024
|
||||
; (and other configurations)
|
||||
```
|
||||
@@ -234,6 +234,11 @@
|
||||
|
||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||
#define GGML_MEM_ALIGN 4
|
||||
#elif defined(__EMSCRIPTEN__)
|
||||
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
|
||||
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
|
||||
#define GGML_MEM_ALIGN 8
|
||||
#else
|
||||
#define GGML_MEM_ALIGN 16
|
||||
#endif
|
||||
|
||||
@@ -144,7 +144,7 @@ extern "C" {
|
||||
// device description: short informative description of the device, could be the model name
|
||||
const char * (*get_description)(ggml_backend_dev_t dev);
|
||||
|
||||
// device memory in bytes
|
||||
// device memory in bytes: 0 bytes to indicate no memory to report
|
||||
void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
|
||||
|
||||
// device type
|
||||
|
||||
@@ -4287,8 +4287,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
|
||||
}
|
||||
|
||||
static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
*free = 1;
|
||||
*total = 1;
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
169
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Normal file
169
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Normal file
@@ -0,0 +1,169 @@
|
||||
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
||||
#define GGML_WEBGPU_SHADER_LIB_HPP
|
||||
|
||||
#include "ggml.h"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#define GGML_WEBGPU_F16_SIZE_BYTES 2
|
||||
#define GGML_WEBGPU_F32_SIZE_BYTES 4
|
||||
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
|
||||
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
|
||||
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
|
||||
#define GGML_WEBGPU_KV_SEQ_PAD 256u
|
||||
|
||||
struct ggml_webgpu_flash_attn_shader_lib_context {
|
||||
ggml_type kv_type;
|
||||
uint32_t head_dim_qk;
|
||||
uint32_t head_dim_v;
|
||||
bool kv_direct;
|
||||
bool has_mask;
|
||||
bool has_sinks;
|
||||
bool uses_logit_softcap;
|
||||
uint32_t sg_mat_m;
|
||||
uint32_t sg_mat_n;
|
||||
uint32_t sg_mat_k;
|
||||
size_t wg_mem_limit_bytes;
|
||||
uint32_t max_subgroup_size;
|
||||
};
|
||||
|
||||
// Tiling and launch parameters chosen while preprocessing the flash-attention
// shader. All fields default to zero.
struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t q_tile{ 0 };   // value emitted as Q_TILE
    uint32_t kv_tile{ 0 };  // value emitted as KV_TILE
    uint32_t wg_size{ 0 };  // value emitted as WG_SIZE
};
|
||||
|
||||
struct ggml_webgpu_processed_shader {
|
||||
std::string wgsl;
|
||||
std::string variant;
|
||||
ggml_webgpu_flash_attn_shader_decisions decisions;
|
||||
};
|
||||
|
||||
// This is exposed because it's necessary in supports_op
|
||||
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
||||
uint32_t kv_tile,
|
||||
uint32_t head_dim_qk,
|
||||
uint32_t head_dim_v,
|
||||
bool has_mask,
|
||||
bool kv_direct) {
|
||||
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
||||
size_t f16_elems = 0;
|
||||
size_t f32_elems = 0;
|
||||
f16_elems += q_tile * head_dim_qk; // q_shmem
|
||||
if (!kv_direct) {
|
||||
f16_elems += kv_tile * max_head_dim; // kv_shmem
|
||||
}
|
||||
f16_elems += q_tile * head_dim_v; // o_shmem
|
||||
if (has_mask) {
|
||||
f16_elems += q_tile * kv_tile; // mask_shmem
|
||||
}
|
||||
f16_elems += q_tile * kv_tile; // inter_shmem
|
||||
f32_elems += q_tile; // row_max_shmem
|
||||
f32_elems += q_tile; // exp_sum_shmem
|
||||
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
}
|
||||
|
||||
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
||||
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
||||
const size_t q_tile = context.sg_mat_m;
|
||||
const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!context.kv_direct) {
|
||||
bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
|
||||
}
|
||||
if (context.has_mask) {
|
||||
bytes_per_kv += q_tile;
|
||||
}
|
||||
bytes_per_kv += q_tile;
|
||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
||||
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
||||
}
|
||||
|
||||
// Specializes the flash-attention WGSL source for the given context.
//
// Builds the list of preprocessor defines (KV type, optional features, head
// dimensions, subgroup-matrix sizes, tile sizes, workgroup size), runs the
// preprocessor, and returns the resulting WGSL together with a variant name
// and the tile/workgroup decisions that were baked into the shader.
//
// Aborts if the KV type is unsupported or if no valid KV tile size exists.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    pre_wgsl::Preprocessor &                          preprocessor,
    const char *                                      shader_src,
    const ggml_webgpu_flash_attn_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "flash_attn";

    switch (context.kv_type) {
        case GGML_TYPE_F32:
            defines.push_back("KV_F32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("KV_F16");
            break;
        case GGML_TYPE_Q4_0:
            defines.push_back("KV_Q4_0");
            break;
        case GGML_TYPE_Q8_0:
            defines.push_back("KV_Q8_0");
            break;
        default:
            GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += std::string("_") + ggml_type_name(context.kv_type);

    if (context.has_mask) {
        defines.push_back("MASK");
        variant += "_mask";
    }
    if (context.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (context.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }

    if (context.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.head_dim_v);

    // For now these are not part of the variant name
    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

    // Add chosen Q/KV tile sizes: the Q tile is one subgroup matrix tall; the KV
    // tile is the preferred number of subgroup matrices, capped by what fits in
    // workgroup memory.
    const uint32_t q_tile  = context.sg_mat_m;
    uint32_t       kv_tile = std::min<uint32_t>(ggml_webgpu_flash_attn_max_kv_tile(context),
                                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    // A zero tile would make the modulo below undefined and produce an unusable shader.
    GGML_ASSERT(kv_tile > 0);
    if (context.kv_direct) {
        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            // Never shrink to zero (or wrap around) while searching for a divisor.
            GGML_ASSERT(kv_tile > context.sg_mat_n);
            kv_tile -= context.sg_mat_n;
        }
    }

    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

    // workgroup size: at least the preferred size, and never below the maximum subgroup size
    const uint32_t wg_size =
        std::max<uint32_t>(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl              = preprocessor.preprocess(shader_src, defines);
    result.variant           = variant;
    result.decisions.q_tile  = q_tile;
    result.decisions.kv_tile = kv_tile;
    result.decisions.wg_size = wg_size;
    return result;
}
|
||||
|
||||
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|
||||
@@ -7,7 +7,9 @@
|
||||
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
@@ -30,7 +32,7 @@
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
||||
# define WEBGPU_DEBUG_BUF_ELEMS 32
|
||||
# define WEBGPU_DEBUG_BUF_ELEMS 512
|
||||
#else
|
||||
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
|
||||
#endif // GGML_WEBGPU_DEBUG
|
||||
@@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool {
|
||||
struct webgpu_pipeline {
|
||||
wgpu::ComputePipeline pipeline;
|
||||
std::string name;
|
||||
void * context = nullptr;
|
||||
};
|
||||
|
||||
struct webgpu_command {
|
||||
@@ -263,6 +266,46 @@ struct webgpu_command {
|
||||
#endif
|
||||
};
|
||||
|
||||
// Lookup key identifying one flash-attention pipeline variant.
// Two keys are equal only when every field matches.
struct flash_attn_pipeline_key {
    int      q_type;
    int      kv_type;
    int      dst_type;
    uint32_t head_dim_qk;
    uint32_t head_dim_v;
    bool     kv_direct;
    bool     has_mask;
    bool     has_sinks;
    bool     uses_logit_softcap;

    bool operator==(const flash_attn_pipeline_key & other) const {
        // Tensor types
        if (q_type != other.q_type || kv_type != other.kv_type || dst_type != other.dst_type) {
            return false;
        }
        // Head dimensions
        if (head_dim_qk != other.head_dim_qk || head_dim_v != other.head_dim_v) {
            return false;
        }
        // Feature flags
        if (kv_direct != other.kv_direct || has_mask != other.has_mask || has_sinks != other.has_sinks) {
            return false;
        }
        return uses_logit_softcap == other.uses_logit_softcap;
    }
};
|
||||
|
||||
// Same hash combine function as in boost: folds std::hash<T>(value) into `seed`.
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    const size_t value_hash = std::hash<T>{}(value);
    const size_t mixed      = value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed                    = seed ^ mixed;
}
|
||||
|
||||
struct flash_attn_pipeline_key_hash {
|
||||
size_t operator()(const flash_attn_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.q_type);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
// All the base objects needed to run operations on a WebGPU device
|
||||
struct webgpu_context_struct {
|
||||
wgpu::Instance instance;
|
||||
@@ -271,12 +314,12 @@ struct webgpu_context_struct {
|
||||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
|
||||
uint32_t subgroup_size;
|
||||
uint32_t max_subgroup_size;
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
bool supports_subgroup_matrix = false;
|
||||
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
|
||||
#endif
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m;
|
||||
uint32_t sg_mat_n;
|
||||
uint32_t sg_mat_k;
|
||||
|
||||
std::recursive_mutex mutex;
|
||||
std::atomic_uint inflight_threads = 0;
|
||||
@@ -284,20 +327,24 @@ struct webgpu_context_struct {
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
|
||||
pre_wgsl::Preprocessor p;
|
||||
|
||||
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
|
||||
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
|
||||
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
|
||||
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
||||
std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
|
||||
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
|
||||
|
||||
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
||||
@@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context {
|
||||
label(std::move(lbl)) {}
|
||||
};
|
||||
|
||||
/* End struct definitions */
|
||||
|
||||
/* WebGPU object initializations */
|
||||
|
||||
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
|
||||
@@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
|
||||
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
||||
wgpu::CommandBuffer commands = encoder.Finish();
|
||||
ctx->queue.Submit(1, &commands);
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
||||
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug data:";
|
||||
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
|
||||
std::cout << " " << i << ": " << debug_data[i];
|
||||
}
|
||||
std::cout << "\n";
|
||||
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
||||
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
||||
ctx->debug_host_buf.Unmap();
|
||||
}
|
||||
#endif
|
||||
@@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
|
||||
return ctx->name.c_str();
|
||||
}
|
||||
|
||||
// TODO: implement proper cleanup
|
||||
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
|
||||
@@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
||||
return ctx->buffer;
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
|
||||
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
|
||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
|
||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
|
||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
||||
}
|
||||
@@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
// The total number of subgroups/workgroups needed per matrix.
|
||||
uint32_t wg_m_sg_tile =
|
||||
WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
|
||||
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
|
||||
uint32_t wg_n_sg_tile =
|
||||
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
|
||||
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
|
||||
uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
|
||||
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
|
||||
uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
|
||||
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
|
||||
} else {
|
||||
#endif
|
||||
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
@@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
// Encodes a flash-attention operation.
//
//   Q, K, V : query / key / value tensors
//   mask    : optional mask tensor (may be nullptr)
//   sinks   : optional sinks tensor (may be nullptr)
//   dst     : output tensor; its op_params are read as three floats:
//             [0] scale, [1] max_bias, [2] logit_softcap
//
// The pipeline for the (types, head dims, feature flags) combination is looked
// up in ctx->flash_attn_pipelines and created on first use.
//
// Changes relative to the original:
//  * the pipeline cache is only ever touched while ctx->mutex is held; the
//    original performed an unlocked find() that could race with a concurrent
//    emplace() from another thread;
//  * float <-> uint32_t reinterpretation uses memcpy instead of pointer casts,
//    which violated strict aliasing.
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                             ggml_tensor *    Q,
                                             ggml_tensor *    K,
                                             ggml_tensor *    V,
                                             ggml_tensor *    mask,
                                             ggml_tensor *    sinks,
                                             ggml_tensor *    dst) {
    float scale;
    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
    float max_bias;
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
    float logit_softcap;
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }
    float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const bool has_mask  = (mask != nullptr);
    const bool has_sinks = (sinks != nullptr);

    // Bit pattern of a float, for passing it through the uint32_t parameter buffer.
    const auto f32_bits = [](float value) {
        uint32_t bits;
        memcpy(&bits, &value, sizeof(bits));
        return bits;
    };
    // Misalignment of a tensor's offset, expressed in elements/blocks of its type.
    const auto misalign_elems = [&ctx](ggml_tensor * t) {
        return (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, t) / ggml_type_size(t->type));
    };
    // Stride of dimension `dim`, expressed in elements/blocks of the tensor's type.
    const auto stride_elems = [](const ggml_tensor * t, int dim) {
        return (uint32_t) (t->nb[dim] / ggml_type_size(t->type));
    };

    std::vector<uint32_t> params = {
        misalign_elems(Q),
        misalign_elems(K),
        misalign_elems(V),
        has_mask ? misalign_elems(mask) : 0u,
        has_sinks ? misalign_elems(sinks) : 0u,
        misalign_elems(dst),
        (uint32_t) Q->ne[2],               // number of heads
        (uint32_t) Q->ne[1],               // sequence length (Q)
        (uint32_t) K->ne[1],               // sequence length (K/V)
        stride_elems(Q, 1),                // stride (elements/blocks) of Q in dimension 1
        stride_elems(Q, 2),                // stride (elements/blocks) of Q in dimension 2
        stride_elems(Q, 3),                // stride (elements/blocks) of Q in dimension 3
        stride_elems(K, 1),                // stride (elements/blocks) of K in dimension 1
        stride_elems(K, 2),                // stride (elements/blocks) of K in dimension 2
        stride_elems(K, 3),                // stride (elements/blocks) of K in dimension 3
        stride_elems(V, 1),                // stride (elements/blocks) of V in dimension 1
        stride_elems(V, 2),                // stride (elements/blocks) of V in dimension 2
        stride_elems(V, 3),                // stride (elements/blocks) of V in dimension 3
        has_mask ? stride_elems(mask, 3) : 0u,  // stride of mask dim 3
        (uint32_t) (Q->ne[2] / K->ne[2]),  // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
        f32_bits(scale),                   // scale (possibly adjusted for logit softcap)
        f32_bits(max_bias),
        f32_bits(logit_softcap),
        f32_bits(n_head_log2),
        f32_bits(m0),
        f32_bits(m1),
    };

    // Bindings are numbered consecutively: Q, K, V, [mask], [sinks], dst.
    std::vector<wgpu::BindGroupEntry> entries;
    uint32_t                          binding_index = 0;
    const auto                        bind_tensor   = [&](ggml_tensor * t) {
        entries.push_back({ .binding = binding_index++,
                            .buffer  = ggml_webgpu_tensor_buf(t),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, t) });
    };
    bind_tensor(Q);
    bind_tensor(K);
    bind_tensor(V);
    if (has_mask) {
        bind_tensor(mask);
    }
    if (has_sinks) {
        bind_tensor(sinks);
    }
    bind_tensor(dst);

    const bool kv_direct =
        (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);

    flash_attn_pipeline_key key = {
        .q_type             = Q->type,
        .kv_type            = K->type,
        .dst_type           = dst->type,
        .head_dim_qk        = (uint32_t) Q->ne[0],
        .head_dim_v         = (uint32_t) V->ne[0],
        .kv_direct          = kv_direct,
        .has_mask           = has_mask,
        .has_sinks          = has_sinks,
        .uses_logit_softcap = logit_softcap != 0.0f,
    };

    webgpu_pipeline                         pipeline;
    ggml_webgpu_flash_attn_shader_decisions decisions = {};

    {
        // Hold the mutex for the lookup as well as the insertion: reading the
        // unordered_map while another thread inserts into it is a data race.
        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);

        auto it = ctx->flash_attn_pipelines.find(key);
        if (it == ctx->flash_attn_pipelines.end()) {
            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
                .kv_type            = K->type,
                .head_dim_qk        = (uint32_t) Q->ne[0],
                .head_dim_v         = (uint32_t) V->ne[0],
                .kv_direct          = kv_direct,
                .has_mask           = has_mask,
                .has_sinks          = has_sinks,
                .uses_logit_softcap = logit_softcap != 0.0f,
                .sg_mat_m           = ctx->sg_mat_m,
                .sg_mat_n           = ctx->sg_mat_n,
                .sg_mat_k           = ctx->sg_mat_k,
                .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize,
                .max_subgroup_size  = ctx->max_subgroup_size,
            };

            ggml_webgpu_processed_shader processed =
                ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);

            webgpu_pipeline created =
                ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
            // The decisions are stored behind the pipeline's untyped context pointer.
            // NOTE(review): nothing visible here frees this allocation — confirm it is
            // released once backend cleanup is implemented.
            created.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);

            it = ctx->flash_attn_pipelines.emplace(key, created).first;
        }

        pipeline  = it->second;
        decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
    }

    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
    uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
|
||||
|
||||
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
ggml_unary_op unary_op = ggml_get_unary_op(dst);
|
||||
@@ -1397,6 +1576,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
return ggml_webgpu_get_rows(ctx, src0, src1, node);
|
||||
case GGML_OP_MUL_MAT:
|
||||
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||
@@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
|
||||
futures.push_back(new_futures);
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx, futures);
|
||||
ctx->inflight_threads--;
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
|
||||
@@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->supports_subgroup_matrix) {
|
||||
std::map<std::string, std::string> sg_matrix_repls;
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
|
||||
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
|
||||
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
|
||||
|
||||
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
proc_mul_mat_f32_f32_vec =
|
||||
@@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
|
||||
webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
|
||||
}
|
||||
|
||||
// TODO: move most initialization logic here
|
||||
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
GGML_UNUSED(params);
|
||||
|
||||
@@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (!webgpu_ctx->supports_subgroup_matrix) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||
size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
|
||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
|
||||
has_mask, kv_direct);
|
||||
if (min_bytes > limit_bytes) {
|
||||
break;
|
||||
}
|
||||
|
||||
supports_op = src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
||||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
||||
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
@@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
}
|
||||
|
||||
// TODO: Does this need to be thread safe? Is it only called once?
|
||||
// TODO: move most logic to device_init function so backend can be freed/initialized properly
|
||||
// Only one device is supported for now
|
||||
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
@@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||
ctx->subgroup_matrix_config = config;
|
||||
ctx->sg_mat_m = config.M;
|
||||
ctx->sg_mat_n = config.N;
|
||||
ctx->sg_mat_k = config.K;
|
||||
valid_subgroup_matrix_config = true;
|
||||
break;
|
||||
}
|
||||
@@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
#endif
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
ctx->max_subgroup_size = info.subgroupMaxSize;
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
||||
@@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
GGML_UNUSED(device);
|
||||
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
std::string(message).c_str());
|
||||
GGML_UNUSED(reason);
|
||||
GGML_UNUSED(message);
|
||||
//TODO: uncomment once proper free logic is in place
|
||||
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
//std::string(message).c_str());
|
||||
});
|
||||
dev_desc.SetUncapturedErrorCallback(
|
||||
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
||||
|
||||
778
ggml/src/ggml-webgpu/pre_wgsl.hpp
Normal file
778
ggml/src/ggml-webgpu/pre_wgsl.hpp
Normal file
@@ -0,0 +1,778 @@
|
||||
#ifndef PRE_WGSL_HPP
|
||||
#define PRE_WGSL_HPP
|
||||
|
||||
#include <cctype>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace pre_wgsl {
|
||||
|
||||
//==============================================================
|
||||
// Options
|
||||
//==============================================================
|
||||
// Preprocessor configuration.
struct Options {
    // Include path setting; defaults to the current directory (".").
    std::string include_path{ "." };
    // List of macro strings; empty by default.
    std::vector<std::string> macros{};
};
|
||||
|
||||
//==============================================================
|
||||
// Utility: trim
|
||||
//==============================================================
|
||||
// Returns a copy of `s` with leading and trailing whitespace (as classified by
// std::isspace) removed. A string that is entirely whitespace yields "".
static std::string trim(const std::string & s) {
    const auto is_space = [](char ch) { return std::isspace((unsigned char) ch) != 0; };

    size_t first = 0;
    size_t last  = s.size();  // one past the last kept character

    for (; first < s.size() && is_space(s[first]); ++first) {
    }
    for (; last > first && is_space(s[last - 1]); --last) {
    }

    return s.substr(first, last - first);
}
|
||||
|
||||
static std::string trim_value(std::istream & is) {
|
||||
std::string str;
|
||||
std::getline(is, str);
|
||||
return trim(str);
|
||||
}
|
||||
|
||||
// True when `c` may appear in an identifier: an underscore or any character
// std::isalnum accepts.
static bool isIdentChar(char c) {
    if (c == '_') {
        return true;
    }
    return std::isalnum(static_cast<unsigned char>(c)) != 0;
}
|
||||
|
||||
static std::string expandMacrosRecursiveInternal(const std::string & line,
|
||||
const std::unordered_map<std::string, std::string> & macros,
|
||||
std::unordered_set<std::string> & visiting);
|
||||
|
||||
static std::string expandMacroValue(const std::string & name,
|
||||
const std::unordered_map<std::string, std::string> & macros,
|
||||
std::unordered_set<std::string> & visiting) {
|
||||
if (visiting.count(name)) {
|
||||
throw std::runtime_error("Recursive macro: " + name);
|
||||
}
|
||||
visiting.insert(name);
|
||||
|
||||
auto it = macros.find(name);
|
||||
if (it == macros.end()) {
|
||||
visiting.erase(name);
|
||||
return name;
|
||||
}
|
||||
|
||||
const std::string & value = it->second;
|
||||
if (value.empty()) {
|
||||
visiting.erase(name);
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
|
||||
visiting.erase(name);
|
||||
return expanded;
|
||||
}
|
||||
|
||||
static std::string expandMacrosRecursiveInternal(const std::string & line,
|
||||
const std::unordered_map<std::string, std::string> & macros,
|
||||
std::unordered_set<std::string> & visiting) {
|
||||
std::string result;
|
||||
result.reserve(line.size());
|
||||
|
||||
size_t i = 0;
|
||||
while (i < line.size()) {
|
||||
if (isIdentChar(line[i])) {
|
||||
size_t start = i;
|
||||
while (i < line.size() && isIdentChar(line[i])) {
|
||||
i++;
|
||||
}
|
||||
std::string token = line.substr(start, i - start);
|
||||
|
||||
auto it = macros.find(token);
|
||||
if (it != macros.end()) {
|
||||
result += expandMacroValue(token, macros, visiting);
|
||||
} else {
|
||||
result += token;
|
||||
}
|
||||
} else {
|
||||
result += line[i];
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string expandMacrosRecursive(const std::string & line,
|
||||
const std::unordered_map<std::string, std::string> & macros) {
|
||||
std::unordered_set<std::string> visiting;
|
||||
return expandMacrosRecursiveInternal(line, macros, visiting);
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// Tokenizer for expressions in #if/#elif
|
||||
//==============================================================
|
||||
// Splits the expression of an #if / #elif directive into tokens.
//
// The lexer does not own the text: `src` is a view, so the underlying
// characters must outlive the lexer.
class ExprLexer {
  public:
    enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };

    struct Tok {
        Kind        kind;
        std::string text;
    };

    explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}

    // Returns the next token, or an END token once the input is exhausted.
    // A character that starts no known token is consumed and also reported as
    // END (with empty text).
    Tok next() {
        skipWS();
        if (pos >= src.size()) {
            return { END, "" };
        }

        const char c = src[pos];

        // Decimal digits only.
        if (isDigit(c)) {
            return { NUMBER, takeWhile(&ExprLexer::isDigit) };
        }

        // [A-Za-z_][A-Za-z0-9_]*
        if (isWordStart(c)) {
            return { IDENT, takeWhile(&ExprLexer::isWordChar) };
        }

        // Two-character operators take precedence over their one-character prefixes.
        const std::string_view pair = src.substr(pos, 2);
        if (pair == "==" || pair == "!=" || pair == "<=" || pair == ">=" || pair == "&&" || pair == "||" ||
            pair == "<<" || pair == ">>") {
            pos += 2;
            return { OP, std::string(pair) };
        }

        // Every remaining case consumes exactly one character.
        ++pos;
        switch (c) {
            case '(':
                return { LPAREN, "(" };
            case ')':
                return { RPAREN, ")" };
            case '+':
            case '-':
            case '*':
            case '/':
            case '%':
            case '<':
            case '>':
            case '!':
                return { OP, std::string(1, c) };
            default:
                // unexpected character
                return { END, "" };
        }
    }

  private:
    std::string_view src;
    size_t           pos;

    static bool isDigit(char ch) { return std::isdigit(static_cast<unsigned char>(ch)) != 0; }

    static bool isWordStart(char ch) { return ch == '_' || std::isalpha(static_cast<unsigned char>(ch)) != 0; }

    static bool isWordChar(char ch) { return ch == '_' || std::isalnum(static_cast<unsigned char>(ch)) != 0; }

    // Consumes the longest run of characters accepted by `pred`, starting at
    // `pos`, and returns it as a string.
    std::string takeWhile(bool (*pred)(char)) {
        const size_t begin = pos;
        while (pos < src.size() && pred(src[pos])) {
            ++pos;
        }
        return std::string(src.substr(begin, pos - begin));
    }

    void skipWS() {
        while (pos < src.size() && std::isspace(static_cast<unsigned char>(src[pos])) != 0) {
            ++pos;
        }
    }
};
|
||||
|
||||
//==============================================================
|
||||
// Expression Parser (recursive descent)
|
||||
//==============================================================
|
||||
class ExprParser {
|
||||
public:
|
||||
ExprParser(std::string_view expr,
|
||||
const std::unordered_map<std::string, std::string> & macros,
|
||||
std::unordered_set<std::string> & visiting) :
|
||||
lex(expr),
|
||||
macros(macros),
|
||||
visiting(visiting) {
|
||||
advance();
|
||||
}
|
||||
|
||||
int parse() { return parseLogicalOr(); }
|
||||
|
||||
private:
|
||||
ExprLexer lex;
|
||||
ExprLexer::Tok tok;
|
||||
const std::unordered_map<std::string, std::string> & macros;
|
||||
std::unordered_set<std::string> & visiting;
|
||||
|
||||
void advance() { tok = lex.next(); }
|
||||
|
||||
bool acceptOp(const std::string & s) {
|
||||
if (tok.kind == ExprLexer::OP && tok.text == s) {
|
||||
advance();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool acceptKind(ExprLexer::Kind k) {
|
||||
if (tok.kind == k) {
|
||||
advance();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int parseLogicalOr() {
|
||||
int v = parseLogicalAnd();
|
||||
while (acceptOp("||")) {
|
||||
int rhs = parseLogicalAnd();
|
||||
v = (v || rhs);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseLogicalAnd() {
|
||||
int v = parseEquality();
|
||||
while (acceptOp("&&")) {
|
||||
int rhs = parseEquality();
|
||||
v = (v && rhs);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseEquality() {
|
||||
int v = parseRelational();
|
||||
for (;;) {
|
||||
if (acceptOp("==")) {
|
||||
int rhs = parseRelational();
|
||||
v = (v == rhs);
|
||||
} else if (acceptOp("!=")) {
|
||||
int rhs = parseRelational();
|
||||
v = (v != rhs);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseRelational() {
|
||||
int v = parseShift();
|
||||
for (;;) {
|
||||
if (acceptOp("<")) {
|
||||
int rhs = parseShift();
|
||||
v = (v < rhs);
|
||||
} else if (acceptOp(">")) {
|
||||
int rhs = parseShift();
|
||||
v = (v > rhs);
|
||||
} else if (acceptOp("<=")) {
|
||||
int rhs = parseShift();
|
||||
v = (v <= rhs);
|
||||
} else if (acceptOp(">=")) {
|
||||
int rhs = parseShift();
|
||||
v = (v >= rhs);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseShift() {
|
||||
int v = parseAdd();
|
||||
for (;;) {
|
||||
if (acceptOp("<<")) {
|
||||
int rhs = parseAdd();
|
||||
v = (v << rhs);
|
||||
} else if (acceptOp(">>")) {
|
||||
int rhs = parseAdd();
|
||||
v = (v >> rhs);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseAdd() {
|
||||
int v = parseMult();
|
||||
for (;;) {
|
||||
if (acceptOp("+")) {
|
||||
int rhs = parseMult();
|
||||
v = (v + rhs);
|
||||
} else if (acceptOp("-")) {
|
||||
int rhs = parseMult();
|
||||
v = (v - rhs);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseMult() {
|
||||
int v = parseUnary();
|
||||
for (;;) {
|
||||
if (acceptOp("*")) {
|
||||
int rhs = parseUnary();
|
||||
v = (v * rhs);
|
||||
} else if (acceptOp("/")) {
|
||||
int rhs = parseUnary();
|
||||
v = (rhs == 0 ? 0 : v / rhs);
|
||||
} else if (acceptOp("%")) {
|
||||
int rhs = parseUnary();
|
||||
v = (rhs == 0 ? 0 : v % rhs);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
int parseUnary() {
|
||||
if (acceptOp("!")) {
|
||||
return !parseUnary();
|
||||
}
|
||||
if (acceptOp("-")) {
|
||||
return -parseUnary();
|
||||
}
|
||||
if (acceptOp("+")) {
|
||||
return +parseUnary();
|
||||
}
|
||||
return parsePrimary();
|
||||
}
|
||||
|
||||
int parsePrimary() {
|
||||
// '(' expr ')'
|
||||
if (acceptKind(ExprLexer::LPAREN)) {
|
||||
int v = parse();
|
||||
if (!acceptKind(ExprLexer::RPAREN)) {
|
||||
throw std::runtime_error("missing ')'");
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
// number
|
||||
if (tok.kind == ExprLexer::NUMBER) {
|
||||
int v = std::stoi(tok.text);
|
||||
advance();
|
||||
return v;
|
||||
}
|
||||
|
||||
// defined(identifier)
|
||||
if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
|
||||
advance();
|
||||
if (acceptKind(ExprLexer::LPAREN)) {
|
||||
if (tok.kind != ExprLexer::IDENT) {
|
||||
throw std::runtime_error("expected identifier in defined()");
|
||||
}
|
||||
std::string name = tok.text;
|
||||
advance();
|
||||
if (!acceptKind(ExprLexer::RPAREN)) {
|
||||
throw std::runtime_error("missing ) in defined()");
|
||||
}
|
||||
return macros.count(name) ? 1 : 0;
|
||||
} else {
|
||||
// defined NAME
|
||||
if (tok.kind != ExprLexer::IDENT) {
|
||||
throw std::runtime_error("expected identifier in defined NAME");
|
||||
}
|
||||
std::string name = tok.text;
|
||||
advance();
|
||||
return macros.count(name) ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
// identifier -> treat as integer, if defined use its value else 0
|
||||
if (tok.kind == ExprLexer::IDENT) {
|
||||
std::string name = tok.text;
|
||||
advance();
|
||||
auto it = macros.find(name);
|
||||
if (it == macros.end()) {
|
||||
return 0;
|
||||
}
|
||||
if (it->second.empty()) {
|
||||
return 1;
|
||||
}
|
||||
return evalMacroExpression(name, it->second);
|
||||
}
|
||||
|
||||
// unexpected
|
||||
return 0;
|
||||
}
|
||||
|
||||
int evalMacroExpression(const std::string & name, const std::string & value) {
|
||||
if (visiting.count(name)) {
|
||||
throw std::runtime_error("Recursive macro: " + name);
|
||||
}
|
||||
|
||||
visiting.insert(name);
|
||||
ExprParser ep(value, macros, visiting);
|
||||
int v = ep.parse();
|
||||
visiting.erase(name);
|
||||
return v;
|
||||
}
|
||||
};
|
||||
|
||||
//==============================================================
|
||||
// Preprocessor
|
||||
//==============================================================
|
||||
class Preprocessor {
|
||||
public:
|
||||
explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
|
||||
// Treat empty include path as current directory
|
||||
if (opts_.include_path.empty()) {
|
||||
opts_.include_path = ".";
|
||||
}
|
||||
parseMacroDefinitions(opts_.macros);
|
||||
}
|
||||
|
||||
std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
|
||||
std::unordered_map<std::string, std::string> macros;
|
||||
std::unordered_set<std::string> predefined;
|
||||
std::unordered_set<std::string> include_stack;
|
||||
buildMacros(additional_macros, macros, predefined);
|
||||
|
||||
std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
|
||||
std::unordered_map<std::string, std::string> macros;
|
||||
std::unordered_set<std::string> predefined;
|
||||
std::unordered_set<std::string> include_stack;
|
||||
buildMacros(additional_macros, macros, predefined);
|
||||
|
||||
std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string preprocess_includes_file(const std::string & filename) {
|
||||
std::unordered_map<std::string, std::string> macros;
|
||||
std::unordered_set<std::string> predefined;
|
||||
std::unordered_set<std::string> include_stack;
|
||||
std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string preprocess_includes(const std::string & contents) {
|
||||
std::unordered_map<std::string, std::string> macros;
|
||||
std::unordered_set<std::string> predefined;
|
||||
std::unordered_set<std::string> include_stack;
|
||||
std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
Options opts_;
|
||||
std::unordered_map<std::string, std::string> global_macros;
|
||||
|
||||
enum class DirectiveMode { All, IncludesOnly };
|
||||
|
||||
struct Cond {
|
||||
bool parent_active;
|
||||
bool active;
|
||||
bool taken;
|
||||
};
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Parse macro definitions into global_macros
|
||||
//----------------------------------------------------------
|
||||
void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
|
||||
for (const auto & def : macro_defs) {
|
||||
size_t eq_pos = def.find('=');
|
||||
if (eq_pos != std::string::npos) {
|
||||
// Format: NAME=VALUE
|
||||
std::string name = trim(def.substr(0, eq_pos));
|
||||
std::string value = trim(def.substr(eq_pos + 1));
|
||||
global_macros[name] = value;
|
||||
} else {
|
||||
// Format: NAME
|
||||
std::string name = trim(def);
|
||||
global_macros[name] = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Build combined macro map and predefined set for a preprocessing operation
|
||||
//----------------------------------------------------------
|
||||
void buildMacros(const std::vector<std::string> & additional_macros,
|
||||
std::unordered_map<std::string, std::string> & macros,
|
||||
std::unordered_set<std::string> & predefined) {
|
||||
macros = global_macros;
|
||||
predefined.clear();
|
||||
|
||||
for (const auto & [name, value] : global_macros) {
|
||||
predefined.insert(name);
|
||||
}
|
||||
|
||||
for (const auto & def : additional_macros) {
|
||||
size_t eq_pos = def.find('=');
|
||||
std::string name, value;
|
||||
if (eq_pos != std::string::npos) {
|
||||
name = trim(def.substr(0, eq_pos));
|
||||
value = trim(def.substr(eq_pos + 1));
|
||||
} else {
|
||||
name = trim(def);
|
||||
value = "";
|
||||
}
|
||||
|
||||
// Add to macros map (will override global if same name)
|
||||
macros[name] = value;
|
||||
predefined.insert(name);
|
||||
}
|
||||
}
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Helpers
|
||||
//----------------------------------------------------------
|
||||
std::string loadFile(const std::string & fname) {
|
||||
std::ifstream f(fname);
|
||||
if (!f.is_open()) {
|
||||
throw std::runtime_error("Could not open file: " + fname);
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << f.rdbuf();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool condActive(const std::vector<Cond> & cond) const {
|
||||
if (cond.empty()) {
|
||||
return true;
|
||||
}
|
||||
return cond.back().active;
|
||||
}
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Process a file
|
||||
//----------------------------------------------------------
|
||||
std::string processFile(const std::string & name,
|
||||
std::unordered_map<std::string, std::string> & macros,
|
||||
const std::unordered_set<std::string> & predefined_macros,
|
||||
std::unordered_set<std::string> & include_stack,
|
||||
DirectiveMode mode) {
|
||||
if (include_stack.count(name)) {
|
||||
throw std::runtime_error("Recursive include: " + name);
|
||||
}
|
||||
|
||||
include_stack.insert(name);
|
||||
std::string shader_code = loadFile(name);
|
||||
std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode);
|
||||
include_stack.erase(name);
|
||||
return out;
|
||||
}
|
||||
|
||||
std::string processIncludeFile(const std::string & fname,
|
||||
std::unordered_map<std::string, std::string> & macros,
|
||||
const std::unordered_set<std::string> & predefined_macros,
|
||||
std::unordered_set<std::string> & include_stack,
|
||||
DirectiveMode mode) {
|
||||
std::string full_path = opts_.include_path + "/" + fname;
|
||||
return processFile(full_path, macros, predefined_macros, include_stack, mode);
|
||||
}
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Process text
|
||||
//----------------------------------------------------------
|
||||
std::string processString(const std::string & shader_code,
|
||||
std::unordered_map<std::string, std::string> & macros,
|
||||
const std::unordered_set<std::string> & predefined_macros,
|
||||
std::unordered_set<std::string> & include_stack,
|
||||
DirectiveMode mode) {
|
||||
std::vector<Cond> cond; // Conditional stack for this shader
|
||||
std::stringstream out;
|
||||
std::istringstream in(shader_code);
|
||||
std::string line;
|
||||
|
||||
while (std::getline(in, line)) {
|
||||
std::string t = trim(line);
|
||||
|
||||
if (!t.empty() && t[0] == '#') {
|
||||
bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
|
||||
if (mode == DirectiveMode::IncludesOnly && !handled) {
|
||||
out << line << "\n";
|
||||
}
|
||||
} else {
|
||||
if (mode == DirectiveMode::IncludesOnly) {
|
||||
out << line << "\n";
|
||||
} else if (condActive(cond)) {
|
||||
// Expand macros in the line before outputting
|
||||
std::string expanded = expandMacrosRecursive(line, macros);
|
||||
out << expanded << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mode == DirectiveMode::All && !cond.empty()) {
|
||||
throw std::runtime_error("Unclosed #if directive");
|
||||
}
|
||||
|
||||
return out.str();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------
|
||||
// Directive handler
|
||||
//----------------------------------------------------------
|
||||
bool handleDirective(const std::string & t,
|
||||
std::stringstream & out,
|
||||
std::unordered_map<std::string, std::string> & macros,
|
||||
const std::unordered_set<std::string> & predefined_macros,
|
||||
std::vector<Cond> & cond,
|
||||
std::unordered_set<std::string> & include_stack,
|
||||
DirectiveMode mode) {
|
||||
// split into tokens
|
||||
std::string body = t.substr(1);
|
||||
std::istringstream iss(body);
|
||||
std::string cmd;
|
||||
iss >> cmd;
|
||||
|
||||
if (cmd == "include") {
|
||||
if (mode == DirectiveMode::All && !condActive(cond)) {
|
||||
return true;
|
||||
}
|
||||
std::string file;
|
||||
iss >> file;
|
||||
if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
|
||||
file = file.substr(1, file.size() - 2);
|
||||
}
|
||||
out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (mode == DirectiveMode::IncludesOnly) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cmd == "define") {
|
||||
if (!condActive(cond)) {
|
||||
return true;
|
||||
}
|
||||
std::string name;
|
||||
iss >> name;
|
||||
// Don't override predefined macros from options
|
||||
if (predefined_macros.count(name)) {
|
||||
return true;
|
||||
}
|
||||
std::string value = trim_value(iss);
|
||||
macros[name] = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "undef") {
|
||||
if (!condActive(cond)) {
|
||||
return true;
|
||||
}
|
||||
std::string name;
|
||||
iss >> name;
|
||||
// Don't undef predefined macros from options
|
||||
if (predefined_macros.count(name)) {
|
||||
return true;
|
||||
}
|
||||
macros.erase(name);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "ifdef") {
|
||||
std::string name;
|
||||
iss >> name;
|
||||
bool p = condActive(cond);
|
||||
bool v = macros.count(name);
|
||||
cond.push_back({ p, p && v, p && v });
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "ifndef") {
|
||||
std::string name;
|
||||
iss >> name;
|
||||
bool p = condActive(cond);
|
||||
bool v = !macros.count(name);
|
||||
cond.push_back({ p, p && v, p && v });
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "if") {
|
||||
std::string expr = trim_value(iss);
|
||||
bool p = condActive(cond);
|
||||
bool v = false;
|
||||
if (p) {
|
||||
std::unordered_set<std::string> visiting;
|
||||
ExprParser ep(expr, macros, visiting);
|
||||
v = ep.parse() != 0;
|
||||
}
|
||||
cond.push_back({ p, p && v, p && v });
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "elif") {
|
||||
std::string expr = trim_value(iss);
|
||||
|
||||
if (cond.empty()) {
|
||||
throw std::runtime_error("#elif without #if");
|
||||
}
|
||||
|
||||
Cond & c = cond.back();
|
||||
if (!c.parent_active) {
|
||||
c.active = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (c.taken) {
|
||||
c.active = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> visiting;
|
||||
ExprParser ep(expr, macros, visiting);
|
||||
bool v = ep.parse() != 0;
|
||||
c.active = v;
|
||||
if (v) {
|
||||
c.taken = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "else") {
|
||||
if (cond.empty()) {
|
||||
throw std::runtime_error("#else without #if");
|
||||
}
|
||||
|
||||
Cond & c = cond.back();
|
||||
if (!c.parent_active) {
|
||||
c.active = false;
|
||||
return true;
|
||||
}
|
||||
if (c.taken) {
|
||||
c.active = false;
|
||||
} else {
|
||||
c.active = true;
|
||||
c.taken = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cmd == "endif") {
|
||||
if (cond.empty()) {
|
||||
throw std::runtime_error("#endif without #if");
|
||||
}
|
||||
cond.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Unknown directive
|
||||
throw std::runtime_error("Unknown directive: #" + cmd);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace pre_wgsl
|
||||
|
||||
#endif // PRE_WGSL_HPP
|
||||
591
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
Normal file
591
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
Normal file
@@ -0,0 +1,591 @@
|
||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#endif
|
||||
|
||||
// Default values
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
|
||||
// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN
|
||||
// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension.
|
||||
#define SG_MAT_M 8
|
||||
#define SG_MAT_N 8
|
||||
#define SG_MAT_K 8
|
||||
|
||||
// Each workgroup processes one subgroup matrix of Q rows
|
||||
#define Q_TILE SG_MAT_M
|
||||
#define KV_TILE 16
|
||||
#define WG_SIZE 64
|
||||
|
||||
// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
|
||||
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
|
||||
|
||||
// Quantization constants/helpers
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
// number of quantized elements processed per thread
|
||||
#if defined(KV_Q4_0)
|
||||
#define NQ 16
|
||||
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
|
||||
#define F16_PER_BLOCK 9
|
||||
#define WEIGHTS_PER_F16 4
|
||||
#elif defined(KV_Q8_0)
|
||||
#define NQ 8
|
||||
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
|
||||
#define F16_PER_BLOCK 17
|
||||
#define WEIGHTS_PER_F16 2
|
||||
#endif
|
||||
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
|
||||
|
||||
// Ok not to put these in a define block, compiler will remove if unused
|
||||
// Extracts byte number `index` (0 = least significant) of `value` as an unsigned value.
fn get_byte(value: u32, index: u32) -> u32 {
    let shift = index * 8u;
    return (value >> shift) & 0xFFu;
}
|
||||
|
||||
// Extracts byte number `index` (0 = least significant) of `value` and sign-extends it to i32.
fn get_byte_i32(value: u32, index: u32) -> i32 {
    let byte = (value >> (index * 8u)) & 0xFFu;
    // Move the byte to the top of the word, then shift back down as a signed
    // value so that its high bit is replicated into the upper 24 bits.
    return bitcast<i32>(byte << 24u) >> 24u;
}
|
||||
|
||||
// Uniform parameters of the flash attention kernel.
// Field order and types define the buffer layout and must not be changed here.
struct Params {
    // starting offsets of each tensor inside its buffer
    offset_q: u32,
    offset_k: u32,
    offset_v: u32,
    offset_mask: u32,
    offset_sinks: u32,
    offset_dst: u32,

    // shapes of Q/K/V
    n_heads: u32,
    seq_len_q: u32,
    seq_len_kv: u32,

    // strides (in elements): 1 = per row, 2 = per head, 3 = per batch
    stride_q1: u32,
    stride_q2: u32,
    stride_q3: u32,
    stride_k1: u32,
    stride_k2: u32,
    stride_k3: u32,
    stride_v1: u32,
    stride_v2: u32,
    stride_v3: u32,
    stride_mask3: u32,

    // number of Q heads sharing one K/V head, e.g., MHA vs. MQA vs. GQA
    q_per_kv: u32,

    // softmax params
    scale: f32,
    max_bias: f32,
    logit_softcap: f32,
    n_head_log2: f32,
    m0: f32,
    m1: f32,
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 5
|
||||
#define PARAMS_BINDING 6
|
||||
#elif defined(MASK)
|
||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#elif defined(SINKS)
|
||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||
#define DST_BINDING 4
|
||||
#define PARAMS_BINDING 5
|
||||
#else
|
||||
#define DST_BINDING 3
|
||||
#define PARAMS_BINDING 4
|
||||
#endif
|
||||
|
||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
|
||||
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||
|
||||
// A large-magnitude negative value, used as the initial running row max and as the score of out-of-range entries.
|
||||
const FLOAT_MIN: f32 = -1.0e9;
|
||||
|
||||
// Shared-memory copy of the Q tile handled by this workgroup: Q_TILE rows of HEAD_DIM_QK f16 values each.
|
||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
// we can reuse the same shmem for K and V since we only need one at a time
|
||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||
#endif
|
||||
|
||||
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem
|
||||
|
||||
#ifdef MASK
|
||||
// storage for mask values
|
||||
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
#endif
|
||||
|
||||
// storage for output of Q*K^T scores for online softmax (S matrix from paper)
|
||||
// also storage for diagonal matrix during online softmax (P matrix from paper)
|
||||
// note that we reuse the same storage for both since we only need one at a time
|
||||
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
|
||||
// Storage for row max and exp sum during online softmax
|
||||
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
|
||||
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
|
||||
|
||||
// Computes the pre-softmax value for one entry of the current score tile.
//   kv_idx     - column inside the KV tile; columns at or beyond the tile width get FLOAT_MIN
//   q_tile_row - row inside the Q tile
//   slope      - factor applied to the mask value (only used when a mask is compiled in)
// The stored score is scaled by params.scale, then optionally soft-capped and
// optionally offset by slope * mask.
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
    let in_tile = kv_idx < KV_TILE;
    let scaled_score = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale;
    var v = select(FLOAT_MIN, scaled_score, in_tile);
#ifdef LOGIT_SOFTCAP
    v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
    let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), in_tile);
    v += slope * mask_val;
#endif
    return v;
}
|
||||
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||
|
||||
// initialize row max for online softmax
|
||||
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
|
||||
row_max_shmem[i] = FLOAT_MIN;
|
||||
exp_sum_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
|
||||
o_shmem[i] = 0.0;
|
||||
}
|
||||
|
||||
// workgroups per head/batch
|
||||
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
||||
let wg_per_batch = wg_per_head * params.n_heads;
|
||||
|
||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||
let dst3_stride = dst2_stride * params.seq_len_q;
|
||||
|
||||
// batch index
|
||||
let batch_idx = wg_id.x / wg_per_batch;
|
||||
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
|
||||
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
|
||||
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
|
||||
let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
|
||||
let wg_in_batch = wg_id.x % wg_per_batch;
|
||||
|
||||
// head index
|
||||
let head_idx = wg_in_batch / wg_per_head;
|
||||
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
|
||||
let k_head_idx = head_idx / params.q_per_kv;
|
||||
let v_head_idx = k_head_idx;
|
||||
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
|
||||
|
||||
// starting Q row for this workgroup
|
||||
let wg_in_head = wg_in_batch % wg_per_head;
|
||||
let q_row_start = wg_in_head * Q_TILE;
|
||||
|
||||
#ifdef MASK
|
||||
// mask offset
|
||||
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
|
||||
#endif
|
||||
|
||||
// note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size]
|
||||
let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
|
||||
|
||||
let head = f32(head_idx);
|
||||
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0);
|
||||
|
||||
// load q tile into shared memory
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let q_row = elem_idx / HEAD_DIM_QK;
|
||||
let q_col = elem_idx % HEAD_DIM_QK;
|
||||
let head_q_row = q_row_start + q_row;
|
||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
q_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
Q[global_q_row_offset + q_col],
|
||||
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
|
||||
}
|
||||
|
||||
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
|
||||
// clear inter_shmem to ensure zero-initialized accumulators
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = 0.0;
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx]; // scale
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx]; // scale
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
kv_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
K[global_k_row_offset + k_col],
|
||||
global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// accumulate q block * k block into registers across the entire KV tile
|
||||
// TODO: this loop seems to be the current largest bottleneck
|
||||
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
||||
let inter_offset = kv_block * SG_MAT_N;
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
|
||||
subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
#ifdef KV_DIRECT
|
||||
let k_block_row = kv_tile + kv_block * SG_MAT_N;
|
||||
let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
|
||||
#else
|
||||
let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
|
||||
#endif
|
||||
for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
|
||||
// load q submatrix from shared memory
|
||||
var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
||||
&q_shmem,
|
||||
head_dim_block,
|
||||
false,
|
||||
HEAD_DIM_QK
|
||||
);
|
||||
|
||||
// load k submatrix from device or shared memory
|
||||
#ifdef KV_DIRECT
|
||||
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
&K,
|
||||
k_global_offset + head_dim_block,
|
||||
true,
|
||||
params.stride_k1
|
||||
);
|
||||
#else
|
||||
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
&kv_shmem,
|
||||
k_block_offset + head_dim_block,
|
||||
true,
|
||||
HEAD_DIM_QK
|
||||
);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
|
||||
}
|
||||
|
||||
// store acc to shared memory for softmax (S matrix from paper)
|
||||
subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
|
||||
}
|
||||
|
||||
#ifdef MASK
|
||||
// load mask tile into shared memory for this KV block
|
||||
// TODO: optimize and skip if mask is -INF for the entire tile
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
let mask_row = elem_idx / KV_TILE;
|
||||
let mask_col = elem_idx % KV_TILE;
|
||||
let global_q_row = q_row_start + mask_row;
|
||||
let global_k_col = kv_tile + mask_col;
|
||||
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
|
||||
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// online softmax
|
||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
// initialize running max for this row
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
var final_max = prev_max;
|
||||
// pass 1: compute final max across the full KV tile in chunks
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
|
||||
final_max = subgroupMax(max(final_max, softmax_term));
|
||||
}
|
||||
|
||||
var total_exp_term: f32 = 0.0;
|
||||
// pass 2: compute exp sum and write P using final_max
|
||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||
let kv_idx = kv_offset + sg_inv_id;
|
||||
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
|
||||
let cur_p = select(0.0,
|
||||
exp(softmax_term - final_max),
|
||||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||
total_exp_term += subgroupAdd(cur_p);
|
||||
if (kv_idx < KV_TILE) {
|
||||
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
|
||||
}
|
||||
}
|
||||
|
||||
let cur_exp = exp(prev_max - final_max);
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
row_max_shmem[q_tile_row] = final_max;
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
|
||||
}
|
||||
}
|
||||
|
||||
// load v tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx]; // scale
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx]; // scale
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
kv_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
V[global_v_row_offset + v_col],
|
||||
global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
||||
// we want to compute O += P * V across the full KV tile
|
||||
for (var head_dim_block = subgroup_id * SG_MAT_N;
|
||||
head_dim_block < HEAD_DIM_V;
|
||||
head_dim_block += num_subgroups * SG_MAT_N) {
|
||||
// load O submatrix from shared memory
|
||||
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
|
||||
&o_shmem,
|
||||
head_dim_block,
|
||||
false,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
|
||||
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
|
||||
let p_offset = kv_block * SG_MAT_N;
|
||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
||||
&inter_shmem,
|
||||
p_offset,
|
||||
false,
|
||||
KV_TILE
|
||||
);
|
||||
|
||||
// load V submatrix from global or shared memory
|
||||
#ifdef KV_DIRECT
|
||||
let v_block_row = kv_tile + kv_block * SG_MAT_N;
|
||||
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
&V,
|
||||
v_global_offset,
|
||||
false,
|
||||
params.stride_v1
|
||||
);
|
||||
#else
|
||||
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
&kv_shmem,
|
||||
v_block_offset + head_dim_block,
|
||||
false,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
#endif
|
||||
// O += P * V
|
||||
o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
|
||||
}
|
||||
|
||||
// store O back to shared memory
|
||||
subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
#ifdef SINKS
|
||||
// add sinks (applied once after processing all KV tiles)
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
// no need to process rows beyond seq_len_q
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
var prev_max = row_max_shmem[q_tile_row];
|
||||
|
||||
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
||||
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
|
||||
let new_max = subgroupMax(max(prev_max, sink_val));
|
||||
let max_exp = exp(prev_max - new_max);
|
||||
let sink_exp = exp(sink_val - new_max);
|
||||
|
||||
let sink_exp_sum = subgroupAdd(sink_exp);
|
||||
|
||||
if (sg_inv_id == 0) {
|
||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
|
||||
}
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
||||
let val = f32(o_shmem[idx]) * max_exp;
|
||||
o_shmem[idx] = f16(val);
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
#endif
|
||||
|
||||
// write output back to global memory
|
||||
for (var q_tile_row = subgroup_id;
|
||||
q_tile_row < Q_TILE;
|
||||
q_tile_row += num_subgroups) {
|
||||
let global_q_row = q_row_start + q_tile_row;
|
||||
if (global_q_row >= params.seq_len_q) {
|
||||
break;
|
||||
}
|
||||
|
||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
||||
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
|
||||
|
||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||
let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
|
||||
let scaled = f32(o_val) * scale;
|
||||
dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1292,7 +1292,9 @@ extern "C" {
|
||||
// available samplers:
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||
|
||||
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
|
||||
@@ -2452,6 +2452,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
pimpl->gpu_buft_list.emplace(dev, std::move(buft_list));
|
||||
}
|
||||
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
// calculate the split points
|
||||
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; });
|
||||
std::vector<float> splits(n_devices());
|
||||
@@ -2462,6 +2467,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
size_t total;
|
||||
size_t free;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
|
||||
// devices can return 0 bytes for free and total memory if they do not
|
||||
// have any to report. in this case, we will use the host memory as a fallback
|
||||
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
|
||||
if (free == 0 && total == 0) {
|
||||
ggml_backend_dev_memory(cpu_dev, &free, &total);
|
||||
}
|
||||
splits[i] = free;
|
||||
}
|
||||
} else {
|
||||
@@ -2478,10 +2490,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
splits[i] /= split_sum;
|
||||
}
|
||||
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0);
|
||||
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1);
|
||||
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
||||
|
||||
@@ -2142,7 +2142,7 @@ struct llama_sampler_xtc {
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
|
||||
std::mt19937 rng;
|
||||
std::mt19937 rng;
|
||||
};
|
||||
|
||||
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
||||
@@ -111,8 +111,20 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < ret.size(); i++) {
|
||||
size_t free, total;
|
||||
size_t free;
|
||||
size_t total;
|
||||
ggml_backend_dev_memory(model->devices[i], &free, &total);
|
||||
|
||||
// devices can return 0 bytes for free and total memory if they do not
|
||||
// have any to report. in this case, we will use the host memory as a fallback
|
||||
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
|
||||
if (free == 0 && total == 0) {
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
ggml_backend_dev_memory(cpu_dev, &free, &total);
|
||||
}
|
||||
ret[i].free = free;
|
||||
ret[i].total = total;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
@@ -16,7 +15,6 @@
|
||||
#include <cstddef>
|
||||
#include <cinttypes>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <filesystem>
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
@@ -2927,9 +2925,14 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
if (task.params.n_cmpl > 1) {
|
||||
task.n_children = task.params.n_cmpl - 1;
|
||||
for (size_t j = 0; j < task.n_children; j++) {
|
||||
server_task child = task.create_child(
|
||||
task.id,
|
||||
rd.get_new_id());
|
||||
server_task child = task.create_child(task.id, rd.get_new_id());
|
||||
|
||||
// use different sampling seed for each child
|
||||
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
|
||||
if (child.params.sampling.seed != LLAMA_DEFAULT_SEED) {
|
||||
child.params.sampling.seed += j + 1;
|
||||
}
|
||||
|
||||
tasks.push_back(std::move(child));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,5 +503,4 @@ def test_chat_completions_multiple_choices():
|
||||
assert len(res.body["choices"]) == 2
|
||||
for choice in res.body["choices"]:
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex("Suddenly", choice["message"]["content"])
|
||||
assert choice["finish_reason"] == "length"
|
||||
|
||||
Reference in New Issue
Block a user